layers.py 57.9 KB
Newer Older
1
2
3
# pylint: disable=unused-argument
import math
from dataclasses import dataclass
4
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
5
6
7
8
9
10

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

11
from vllm.adapter_commons.layers import AdapterMapping
12
from vllm.config import LoRAConfig
13
14
15
16
17
18
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce,
                              tensor_model_parallel_gather)
19
from vllm.distributed.utils import divide
20
from vllm.lora.punica import PunicaWrapper
21
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
22
                                               MergedColumnParallelLinear,
23
                                               QKVParallelLinear,
24
                                               ReplicatedLinear,
25
                                               RowParallelLinear)
26
from vllm.model_executor.layers.logits_processor import LogitsProcessor
27
28
from vllm.model_executor.layers.rotary_embedding import (
    LinearScalingRotaryEmbedding, RotaryEmbedding)
29
from vllm.model_executor.layers.vocab_parallel_embedding import (
30
    VocabParallelEmbedding)
31
32
33
34
35

if TYPE_CHECKING:
    pass


36
37
38
def _get_lora_device(base_layer: nn.Module) -> torch.device:
    # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
    """Returns the device for where to place the LoRA tensors."""
Jee Li's avatar
Jee Li committed
39
    # unquantizedLinear
40
41
    if hasattr(base_layer, "weight"):
        return base_layer.weight.device
42
43
44
    # Compressed Tensor
    elif hasattr(base_layer, "weight_packed"):
        return base_layer.weight_packed.device
45
    # GPTQ/AWQ
Jee Li's avatar
Jee Li committed
46
47
48
49
50
51
52
    elif hasattr(base_layer, "qweight"):
        return base_layer.qweight.device
    # marlin
    elif hasattr(base_layer, "B"):
        return base_layer.B.device
    else:
        raise ValueError(f"Unsupported base layer: {base_layer}")
53
54


55
56
57
58
59
60
61
def _not_fully_sharded_can_replace(can_replace):
    """
    decorator which adds the condition of not using fully sharded loras
    intended to wrap can_replace_layer()
    """

    def dec(*args, **kwargs):
62
63
        decorate = kwargs.pop("decorate") if "decorate" in kwargs else True
        condition = (not kwargs["lora_config"].fully_sharded_loras
64
65
66
67
68
69
                     if decorate else True)
        return can_replace(*args, **kwargs) and condition

    return dec


70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
def apply_bias(
    indices: torch.Tensor,
    output: torch.Tensor,
    bias_stacked: torch.Tensor,
):
    """Add each row's LoRA bias into ``output`` (in place) and return it.

    Input shapes:
        bias_stacked:    (num_loras, output_dim); any extra leading dims are
                         flattened by ``view(-1, output_dim)``
        indices:         (batch_size)
        output:          (batch_size, output_dim)

    Rows whose index equals -1 receive a zero bias. The sum is written into
    ``output``'s storage and a tensor shaped like ``output`` is returned.
    """
    flat_output = output.view(-1, output.shape[-1])
    token_indices = indices.view(-1)

    # Advanced indexing gathers a copy, so zeroing rows of it below never
    # modifies bias_stacked itself.
    per_row_bias = bias_stacked.view(-1, bias_stacked.shape[-1])[token_indices]
    per_row_bias[token_indices == -1] = 0

    flat_output += per_row_bias
    return flat_output.view_as(output)


def apply_bias_packed_nslice(
    indices: torch.Tensor,
    output: torch.Tensor,
    output_slices: Tuple[int, ...],
    bias_stacked: Tuple[Optional[torch.Tensor], ...],
):
    """Add per-slice LoRA biases into a packed ``output`` (in place).

    The last dimension of ``output`` is the concatenation of
    ``len(output_slices)`` slices; slice ``i`` is ``output_slices[i]`` wide
    and receives its bias from ``bias_stacked[i]``.

    Input shapes:
        bias_stacked:      n-element tuple; each entry is None (that slice
                           gets no bias) or (num_loras, slice_size) with any
                           extra leading dims flattened
        indices:           (batch_size)
        output:            (batch_size, sum(output_slices))
        output_slices:     n-element tuple of slice sizes,
                           where n is number of slices

    Rows whose index equals -1 receive a zero bias. The sum is written into
    ``output``'s storage and a tensor shaped like ``output`` is returned.
    """
    org_output = output
    output = output.view(-1, output.shape[-1])
    indices = indices.view(-1)

    offset_left = 0
    # ``slice_size`` (not ``slice``) to avoid shadowing the builtin.
    for slice_idx, slice_size in enumerate(output_slices):
        bias = bias_stacked[slice_idx]
        if bias is not None:
            bias = bias.view(-1, bias.shape[-1])
            # Advanced indexing returns a copy; zeroing it is safe.
            bias = bias[indices]
            bias[indices == -1] = 0
            output[:, offset_left:offset_left + slice_size] += bias

        offset_left += slice_size

    return output.view_as(org_output)


127
@dataclass
class LoRAMapping(AdapterMapping):
    """``AdapterMapping`` for LoRA, extended with a prefill flag."""

    # Prefill flag (defaults to False). NOTE(review): how it is consumed is
    # defined by callers outside this file.
    is_prefill: bool = False
130
131
132
133


class BaseLayerWithLoRA(nn.Module):
    """Common interface for the LoRA-wrapped layers in this module.

    Subclasses override these hooks to allocate, reset, load and apply LoRA
    weights. The default hook bodies do nothing and return None, except
    ``set_mapping`` (stores the Punica wrapper) and ``can_replace_layer``
    (must be overridden).
    """

    def slice_lora_a(
        self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
        """Slice lora a if splitting for tensor parallelism."""
        ...

    def slice_lora_b(
        self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
        """Slice lora b if splitting with tensor parallelism."""
        ...

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Initializes lora matrices."""
        ...

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Overwrites lora tensors at index."""
        ...

    def set_mapping(
        self,
        punica_wrapper: PunicaWrapper,
    ):
        """Store the Punica wrapper used by forward/apply of subclasses."""
        self.punica_wrapper: PunicaWrapper = punica_wrapper

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        raise NotImplementedError

187
188
189
190
191
192

class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper around ``VocabParallelEmbedding``.

    Holds stacked LoRA A/B matrices for up to ``max_loras`` adapters plus
    per-adapter extra-vocabulary embeddings. When this partition owns
    added-vocabulary rows of the base weight, ``set_lora`` copies the extra
    embeddings straight into those rows.
    """

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        # Both are assigned in create_lora_weights().
        self.embeddings_slice: Optional[Tuple[int, int]]
        self.embeddings_weights: Optional[torch.Tensor]

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        """Allocate zeroed LoRA buffers for ``max_loras`` adapters."""
        if self.base_layer.num_added_embeddings_per_partition > 0:
            # We can start adding lora weights
            # Keep a view onto this partition's added-vocab rows so that
            # set_lora() writes extra embeddings directly into the base
            # weight, and remember which range of the extra vocabulary
            # those rows correspond to.
            self.embeddings_weights = self.base_layer.weight.data[
                self.base_layer.num_org_embeddings_per_partition:self.
                base_layer.num_org_embeddings_per_partition +
                self.base_layer.num_added_embeddings_per_partition]
            self.embeddings_slice = (
                self.base_layer.shard_indices.added_vocab_start_index -
                self.base_layer.org_vocab_size,
                self.base_layer.shard_indices.added_vocab_end_index -
                self.base_layer.org_vocab_size)
            self.base_layer.weight.data[
                self.base_layer.num_org_embeddings_per_partition:].fill_(0)
        else:
            self.embeddings_slice = None
            self.embeddings_weights = None

        self.embeddings_tensors = torch.zeros(
            (
                max_loras,
                lora_config.lora_extra_vocab_size,
                self.base_layer.embedding_dim,
            ),
            dtype=self.base_layer.weight.dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                self.base_layer.org_vocab_size +
                lora_config.lora_extra_vocab_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.embedding_dim,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        # 2-D view (max_loras * vocab, rank) used as an embedding table in
        # forward(); shares storage with lora_a_stacked.
        self.lora_a_stacked_2d = self.lora_a_stacked.view(
            self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
            self.lora_a_stacked.shape[2],
        )

    def reset_lora(self, index: int):
        """Zero every buffer slot belonging to adapter ``index``."""
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Load adapter weights into slot ``index`` (``bias`` is ignored)."""
        self.reset_lora(index)
        self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
            lora_a, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1], ].copy_(embeddings_tensor, non_blocking=True)
            if self.embeddings_slice is not None:
                # TODO(yard1): Optimize this copy, we don't need to copy
                # everything, just the modified part
                embeddings = self.embeddings_tensors.view(
                    self.embeddings_tensors.shape[0] *
                    self.embeddings_tensors.shape[1],
                    self.embeddings_tensors.shape[2],
                )[self.embeddings_slice[0]:self.embeddings_slice[1]]
                assert self.embeddings_weights is not None
                self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed token ids ``x`` and add the LoRA contribution.

        ``x`` is left unmodified.
        """
        # Ids >= org_vocab_size belong to the added vocabulary.
        added_tokens_mask = x > self.base_layer.org_vocab_size - 1
        embeddings_indices = self.punica_wrapper.embeddings_indices
        # Row 1 offsets the ids into the flattened lora_a_stacked_2d table.
        indices = embeddings_indices[1].view_as(x)
        full_lora_a_embeddings = F.embedding(
            x + indices,
            self.lora_a_stacked_2d,
        )
        # Row 0 offsets only the added-vocab ids for the base lookup.
        indices = embeddings_indices[0].view_as(x)
        # Out-of-place add: the previous ``x.add_(...)`` silently rewrote
        # the caller's input-id tensor.
        full_output = self.base_layer.forward(x +
                                              (indices * added_tokens_mask))

        full_output_org = full_output
        if full_output.ndim == 3:
            full_output = full_output.view(
                full_output.shape[0] * full_output.shape[1], -1)
        if full_lora_a_embeddings.ndim == 3:
            full_lora_a_embeddings = full_lora_a_embeddings.view(
                full_lora_a_embeddings.shape[0] *
                full_lora_a_embeddings.shape[1],
                -1,
            )

        # Embedding layer only need expand op
        self.punica_wrapper.add_expand(full_output,
                                       full_lora_a_embeddings,
                                       self.lora_b_stacked,
                                       add_input=True)
        return full_output.view_as(full_output_org)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Match exactly ``VocabParallelEmbedding`` (not subclasses)."""
        return type(source_layer) is VocabParallelEmbedding

327

328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper around ``ReplicatedLinear``.

    Adapter weights are stored and applied without any tensor-parallel
    slicing.
    """

    def __init__(self, base_layer: ReplicatedLinear) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.input_size = self.base_layer.input_size
        self.output_size = self.base_layer.output_size
        self.device = _get_lora_device(self.base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate zeroed LoRA A/B (and optional bias) buffers."""
        self.lora_config = lora_config
        lora_a_output_size = lora_config.max_lora_rank
        self.lora_a_stacked = torch.zeros(
            max_loras,
            1,
            lora_a_output_size,
            self.input_size,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            max_loras,
            1,
            self.output_size,
            lora_config.max_lora_rank,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        if lora_config.bias_enabled:
            self.bias_stacked = torch.zeros(
                max_loras,
                1,
                self.output_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
        else:
            self.bias_stacked = None

    def reset_lora(self, index: int):
        """Zero every buffer slot belonging to adapter ``index``."""
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        if self.lora_config.bias_enabled:
            self.bias_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Load adapter weights (and optional 1-D bias) into slot ``index``."""
        self.reset_lora(index)

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if bias is not None:
            # The bias is 1-D; ``.T`` was a no-op there and is deprecated by
            # PyTorch for non-2-D tensors, so copy it directly.
            self.bias_stacked[index,
                              0, :bias.shape[0]].copy_(bias,
                                                       non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Base-layer matmul, then LoRA bias (if enabled) and LoRA delta."""
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        if self.bias_stacked is not None:
            self.indices = self.punica_wrapper.token_lora_indices
            output = apply_bias(
                self.indices,
                output,
                self.bias_stacked,
            )
        self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
                                     self.lora_b_stacked, 1.0)
        return output

    def forward(self, input_):
        """Forward of ReplicatedLinearWithLoRA

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output = self.apply(input_, bias)

        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Match exactly ``ReplicatedLinear`` (not subclasses)."""
        return type(source_layer) is ReplicatedLinear


445
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
    """
    LoRA on top of ColumnParallelLinear layer.

    LoRA B is sliced for tensor parallelism.
    """

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__()
        # The base_layer type is ColumnParallelLinear or
        # MergedColumnParallelLinear, their weight sharding logic is
        # inconsistent when TP is greater than 1.
        self.is_merged_col_linear = type(
            base_layer) is MergedColumnParallelLinear

        self.base_layer = base_layer
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size = self.base_layer.input_size
        self.output_size = self.base_layer.output_size_per_partition
        self.device = _get_lora_device(self.base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate zeroed LoRA A/B (and optional bias) buffers."""
        self.lora_config = lora_config
        self.tp_size = get_tensor_model_parallel_world_size()
        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))
        self.lora_a_stacked = torch.zeros(
            max_loras,
            1,
            lora_a_output_size_per_partition,
            self.input_size,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            max_loras,
            1,
            self.output_size,
            lora_config.max_lora_rank,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )

        if lora_config.bias_enabled:
            self.bias_stacked = torch.zeros(
                max_loras,
                1,
                self.output_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
        else:
            self.bias_stacked = None

        self.output_dim = self.lora_b_stacked.shape[2]

    def reset_lora(self, index: int):
        """Zero every buffer slot belonging to adapter ``index``."""
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        if self.lora_config.bias_enabled:
            self.bias_stacked[index] = 0

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """LoRA A is replicated across ranks; return it unchanged."""
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Return this rank's column shard of a full LoRA B matrix."""
        # Applicable to cases where the base_layer is
        # MergedColumnParallelLinear.
        if self.is_merged_col_linear:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = self.output_size // 2
            offset = lora_b.shape[-1] // 2

            left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) *
                                 shard_size]
            right_weight = lora_b[:, offset + tp_rank * shard_size:offset +
                                  (tp_rank + 1) * shard_size]
            lora_b = torch.cat([left_weight, right_weight], dim=1)
        # Applicable to cases where the base_layer is
        # ColumnParallelLinear.
        else:
            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
            shard_size = self.output_dim
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            lora_b = lora_b[:, start_idx:end_idx]
        return lora_b

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        """Return this rank's shard of a full 1-D LoRA bias (None passes).

        Mirrors ``slice_lora_b``: for a MergedColumnParallelLinear base the
        full bias is two equal halves concatenated, so this rank's shard is
        taken from each half; otherwise one contiguous shard is taken.
        """
        if bias is None:
            return bias
        tp_rank = get_tensor_model_parallel_rank()
        if self.is_merged_col_linear:
            shard_size = self.output_size // 2
            offset = bias.shape[-1] // 2
            left_bias = bias[tp_rank * shard_size:(tp_rank + 1) * shard_size]
            right_bias = bias[offset + tp_rank * shard_size:offset +
                              (tp_rank + 1) * shard_size]
            return torch.cat([left_bias, right_bias])
        shard_size = self.output_dim
        start_idx = tp_rank * shard_size
        end_idx = (tp_rank + 1) * shard_size
        return bias[start_idx:end_idx]

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Shard (when TP > 1) and load adapter weights into slot ``index``."""
        self.reset_lora(index)

        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
            bias = self.slice_bias(bias)

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if bias is not None:
            # The bias is 1-D; ``.T`` was a no-op there and is deprecated by
            # PyTorch for non-2-D tensors, so copy it directly.
            self.bias_stacked[index,
                              0, :bias.shape[0]].copy_(bias,
                                                       non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Base-layer matmul, then LoRA bias (if enabled) and LoRA delta."""
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        if self.bias_stacked is not None:
            self.indices = self.punica_wrapper.token_lora_indices
            output = apply_bias(
                self.indices,
                output,
                self.bias_stacked,
            )
        self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
                                     self.lora_b_stacked, 1.0)
        return output

    def forward(self, input_):
        """Forward of ColumnParallelLinear

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output_parallel = self.apply(input_, bias)
        if self.base_layer.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Match ColumnParallelLinear, or a single-LoRA merged layer."""
        return type(source_layer) is ColumnParallelLinear or (
            type(source_layer) is MergedColumnParallelLinear
            and len(packed_modules_list) == 1)

627
628
629
630
631
632
633
634
635
636
637
638
639
640

class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
    packed together (eg. gate_proj + up_proj -> gate_up_proj).

    This means we have 2 LoRAs, each applied to one half of the layer.

    Both slices must have the same size.
    """

    def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate one zeroed A/B (and optional bias) buffer per slice.

        Raises:
            ValueError: unless the base layer has exactly 2 equal-size
                output slices.
        """
        self.lora_config = lora_config
        n_slices = 2
        if not (len(self.base_layer.output_sizes) == n_slices
                and self.base_layer.output_sizes[0]
                == self.base_layer.output_sizes[1]):
            raise ValueError(
                "LoRAColumnParallelLinear2Slice requires 2 slices with "
                "the same size.")
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                self.output_size // 2,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(n_slices))
        if lora_config.bias_enabled:
            self.bias_stacked = tuple(
                torch.zeros(
                    max_loras,
                    1,
                    self.output_size // 2,
                    dtype=lora_config.lora_dtype,
                    device=self.device,
                ) for _ in range(n_slices))
        else:
            self.bias_stacked = None

        # Width of a single slice on this rank.
        self.output_dim = self.lora_b_stacked[0].shape[2]

    def reset_lora(self, index: int):
        """Zero every per-slice buffer slot belonging to adapter ``index``."""
        for stacked in (*self.lora_a_stacked, *self.lora_b_stacked):
            stacked[index] = 0
        if self.lora_config.bias_enabled:
            for stacked in self.bias_stacked:
                stacked[index] = 0

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """LoRA A is replicated across ranks; return it unchanged."""
        return lora_a

    def slice_lora_b(
        self, lora_b: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """Return this rank's column shard of each sub-LoRA B."""
        #NOTE: lora_b contains 2 subloras, and each sublora could be None.
        shard_size = self.output_dim
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        lora_b = [
            lora_b[0][:, start_idx:end_idx] if lora_b[0] is not None else None,
            lora_b[1][:, start_idx:end_idx] if lora_b[1] is not None else None,
        ]
        return lora_b

    def slice_bias(
        self, bias: List[Union[torch.Tensor,
                               None]]) -> List[Union[torch.Tensor, None]]:
        """Return this rank's shard of each sub-LoRA bias."""
        # NOTE : each bias could be None.
        shard_size = self.output_dim
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        bias = [
            bias[0][start_idx:end_idx] if bias[0] is not None else None,
            bias[1][start_idx:end_idx] if bias[1] is not None else None
        ]
        return bias

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Shard (when TP > 1) and load both sub-LoRAs into slot ``index``.

        ``lora_a``/``lora_b``/``bias`` are 2-element sequences, one entry per
        slice; a ``None`` entry leaves that slice zeroed.
        """
        self.reset_lora(index)

        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
            if bias is not None:
                bias = self.slice_bias(bias)

        for i in range(len(self.lora_a_stacked)):
            if lora_a[i] is not None:
                self.lora_a_stacked[i][
                    index, 0, :lora_a[i].shape[1], :lora_a[i].shape[0]].copy_(
                        lora_a[i].T, non_blocking=True)
                self.lora_b_stacked[i][
                    index, 0, :lora_b[i].shape[1], :lora_b[i].shape[0]].copy_(
                        lora_b[i].T, non_blocking=True)
            if bias is not None and bias[i] is not None:
                # Each bias is 1-D; ``.T`` was a no-op there and is
                # deprecated by PyTorch for non-2-D tensors.
                self.bias_stacked[i][index, 0, :bias[i].shape[0]].copy_(
                    bias[i], non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Base-layer matmul, then per-slice LoRA bias and LoRA deltas."""
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        if self.bias_stacked is not None:
            self.indices = self.punica_wrapper.token_lora_indices
            output = apply_bias_packed_nslice(
                self.indices,
                output,
                (self.output_dim, self.output_dim),
                self.bias_stacked,
            )
        self.punica_wrapper.add_lora_packed_nslice(
            output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0,
            (self.output_dim, self.output_dim))
        return output

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Match a MergedColumnParallelLinear packing exactly 2 modules."""
        return (type(source_layer) is MergedColumnParallelLinear
                and len(packed_modules_list) == 2)
799

800
801

class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """
    ColumnParallelLinear layer that is specifically designed for
    qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
    only contain a single LoRA within their qkv_proj layer.

    During inference with Tensor Parallel, the weights of lora_b
    must be accurately partitioned according to the respective ranks.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        self.tp_size = get_tensor_model_parallel_world_size()
        self.q_proj_total_size = (self.base_layer.total_num_heads *
                                  self.base_layer.head_size)
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
                                   self.base_layer.head_size)

    def _qkv_shard_bounds(self) -> List[Tuple[int, int]]:
        """Return this rank's ``[start, end)`` ranges for the Q, K and V
        shards within the full (unsharded) packed qkv output dimension.

        Also records ``q_shard_id`` / ``kv_shard_id`` on the instance. KV
        heads may be replicated (``num_kv_head_replicas``), so consecutive
        ranks can map to the same KV shard.
        """
        tp_rank = get_tensor_model_parallel_rank()
        self.q_shard_id = tp_rank
        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas

        q_start = self.q_proj_shard_size * self.q_shard_id
        # Packed layout is [Q (total) | K (total) | V (total)].
        k_start = (self.q_proj_total_size +
                   self.kv_proj_shard_size * self.kv_shard_id)
        v_start = k_start + self.kv_proj_total_size
        return [
            (q_start, q_start + self.q_proj_shard_size),
            (k_start, k_start + self.kv_proj_shard_size),
            (v_start, v_start + self.kv_proj_shard_size),
        ]

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Keep only this rank's Q/K/V output columns of ``lora_b``."""
        return torch.cat(
            [lora_b[:, start:end] for start, end in self._qkv_shard_bounds()],
            dim=1)

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        """Keep only this rank's Q/K/V entries of the 1-D LoRA ``bias``.

        The bias is 1-D, so the slices are concatenated along ``dim=0``
        (``dim=1`` does not exist for 1-D tensors and would raise).
        """
        return torch.cat(
            [bias[start:end] for start, end in self._qkv_shard_bounds()],
            dim=0)

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Load one adapter's weights (and optional bias) into slot ``index``.

        With tensor parallelism the weights and bias are first sliced to
        this rank's shard.
        """
        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
            if bias is not None:
                bias = self.slice_bias(bias)

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if bias is not None:
            # The bias is 1-D: no transpose needed (``.T`` on non-2-D
            # tensors is deprecated in PyTorch).
            self.bias_stacked[index, 0, :bias.shape[0]].copy_(
                bias, non_blocking=True)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is QKVParallelLinear and len(
            packed_modules_list) == 1


class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 3 sublayers (slices)
    packed together in qkv proj fashion
    (q_proj + k_proj + v_proj -> qkv_proj).

    This means we have 3 LoRAs, each applied to one slice of the layer.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate zeroed LoRA buffers for the q, k and v slices.

        ``lora_a_stacked``, ``lora_b_stacked`` and (if bias is enabled)
        ``bias_stacked`` are 3-tuples ordered (q, k, v). The B and bias
        buffers are sized to this rank's shard of each projection.
        """
        self.lora_config = lora_config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.q_shard_id = self.tp_rank
        # KV heads may be replicated, so consecutive ranks can share a shard.
        self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas

        # Per-slice output widths on this rank, ordered (q, k, v).
        self.output_slices = (
            self.q_proj_shard_size,
            self.kv_proj_shard_size,
            self.kv_proj_shard_size,
        )

        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))

        # One A buffer per slice; all three share the same shape.
        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in self.output_slices)
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                slice_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for slice_size in self.output_slices)
        if lora_config.bias_enabled:
            self.bias_stacked = tuple(
                torch.zeros(
                    max_loras,
                    1,
                    slice_size,
                    dtype=lora_config.lora_dtype,
                    device=self.device,
                ) for slice_size in self.output_slices)
        else:
            self.bias_stacked = None

        self.packed_indices: Optional[torch.Tensor] = None
        self.standard_indices: Optional[torch.Tensor] = None
        # lazily initialized.
        self.indices: torch.Tensor
        self.indices_len: List[int]

    def reset_lora(self, index: int):
        """Zero every q/k/v LoRA buffer (and bias, if enabled) at ``index``."""
        for stacked_a, stacked_b in zip(self.lora_a_stacked,
                                        self.lora_b_stacked):
            stacked_a[index] = 0
            stacked_b[index] = 0
        if self.lora_config.bias_enabled:
            for stacked_bias in self.bias_stacked:
                stacked_bias[index] = 0

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """LoRA A is not sharded for this layer; return it unchanged."""
        return lora_a

    def _qkv_shard_bounds(self) -> List[Tuple[int, int]]:
        """Return this rank's ``[start, end)`` range within each of the
        separate (unpacked) q, k and v tensors, in that order."""
        q_start = self.q_proj_shard_size * self.q_shard_id
        kv_start = self.kv_proj_shard_size * self.kv_shard_id
        return [
            (q_start, q_start + self.q_proj_shard_size),
            (kv_start, kv_start + self.kv_proj_shard_size),
            (kv_start, kv_start + self.kv_proj_shard_size),
        ]

    def slice_lora_b(
        self, lora_b: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """Keep this rank's output columns of each q/k/v LoRA B.

        ``None`` entries (slices without a LoRA) pass through unchanged.
        """
        return [
            None if slice_b is None else slice_b[:, start:end]
            for slice_b, (start, end) in zip(lora_b, self._qkv_shard_bounds())
        ]

    def slice_bias(
        self, bias: List[Union[torch.Tensor,
                               None]]) -> List[Union[torch.Tensor, None]]:
        """Keep this rank's entries of each 1-D q/k/v bias; ``None`` passes
        through unchanged."""
        return [
            None if slice_bias is None else slice_bias[start:end]
            for slice_bias, (start, end) in zip(bias,
                                                self._qkv_shard_bounds())
        ]

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Load one adapter's q/k/v weights (and optional biases) into slot
        ``index``.

        ``lora_a``, ``lora_b`` and ``bias`` are indexed per slice (q, k, v);
        ``None`` entries are skipped. With tensor parallelism they are first
        sliced to this rank's shard.
        """
        self.reset_lora(index)

        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
            if bias is not None:
                bias = self.slice_bias(bias)

        for stacked_b, slice_b in zip(self.lora_b_stacked, lora_b):
            if slice_b is not None:
                stacked_b[index, 0, :slice_b.shape[1], :slice_b.shape[0]].copy_(
                    slice_b.T, non_blocking=True)

        for stacked_a, slice_a in zip(self.lora_a_stacked, lora_a):
            if slice_a is not None:
                stacked_a[index, 0, :slice_a.shape[1], :slice_a.shape[0]].copy_(
                    slice_a.T, non_blocking=True)

        if bias is not None:
            for stacked_bias, slice_bias in zip(self.bias_stacked, bias):
                if slice_bias is not None:
                    # Biases are 1-D: copy directly (``.T`` on non-2-D
                    # tensors is deprecated in PyTorch).
                    stacked_bias[index, 0, :slice_bias.shape[0]].copy_(
                        slice_bias, non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the base projection, then add LoRA bias and LoRA deltas for
        the three q/k/v output slices."""
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        if self.bias_stacked is not None:
            self.indices = self.punica_wrapper.token_lora_indices
            output = apply_bias_packed_nslice(
                self.indices,
                output,
                self.output_slices,
                self.bias_stacked,
            )
        self.punica_wrapper.add_lora_packed_nslice(output, x,
                                                   self.lora_a_stacked,
                                                   self.lora_b_stacked, 1.0,
                                                   self.output_slices)
        return output

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return (type(source_layer) is QKVParallelLinear
                and len(packed_modules_list) == 3)
1157

1158
1159
1160
1161
1162
1163

class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper around RowParallelLinear.

    The input dimension is sharded across tensor-parallel ranks. LoRA A is
    therefore sliced along its input rows (``slice_lora_a``), while LoRA B
    and the LoRA bias are kept whole on every rank.
    """

    def __init__(self, base_layer: RowParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.input_size = self.base_layer.input_size_per_partition
        self.output_size = self.base_layer.output_size
        self.device = _get_lora_device(self.base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate zeroed LoRA A/B (and optional bias) buffers for
        ``max_loras`` adapter slots."""
        self.lora_config = lora_config
        self.tp_rank = get_tensor_model_parallel_rank()
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.input_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        tp_size = get_tensor_model_parallel_world_size()
        lora_b_output_size_per_partition = (
            self.output_size if not lora_config.fully_sharded_loras else
            divide(self.output_size, tp_size))

        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_b_output_size_per_partition,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )

        if lora_config.bias_enabled:
            self.bias_stacked = torch.zeros(
                (
                    max_loras,
                    1,
                    self.output_size,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
        else:
            self.bias_stacked = None
        # Lazily initialized
        self.indices: torch.Tensor
        self.indices_len: List[int]

    def reset_lora(self, index: int):
        """Zero the LoRA buffers (and bias, if enabled) for slot ``index``."""
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        if self.lora_config.bias_enabled:
            self.bias_stacked[index] = 0

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """Keep the rows of ``lora_a`` matching this rank's input shard."""
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        shard_size = self.input_size
        start_idx = tensor_model_parallel_rank * shard_size
        end_idx = (tensor_model_parallel_rank + 1) * shard_size
        lora_a = lora_a[start_idx:end_idx, :]
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """LoRA B is not sharded for a row-parallel layer."""
        return lora_b

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        """The LoRA bias is not sharded for a row-parallel layer."""
        return bias

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Load one adapter's weights (and optional bias) into slot ``index``."""
        self.reset_lora(index)

        if self.base_layer.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
            if bias is not None:
                bias = self.slice_bias(bias)

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if bias is not None:
            # The bias is 1-D: copy directly (``.T`` on non-2-D tensors is
            # deprecated in PyTorch).
            self.bias_stacked[index, 0, :bias.shape[0]].copy_(
                bias, non_blocking=True)

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        """Compute this rank's partial output plus its LoRA contribution.

        The result is a partial sum over the sharded input; ``forward``
        all-reduces it across ranks when ``reduce_results`` is set and
        tp_size > 1.
        """
        output = self.base_layer.quant_method.apply(self.base_layer, x)

        if self.bias_stacked is not None:
            self.indices = self.punica_wrapper.token_lora_indices
            # Every rank holds the full, unsharded LoRA bias (slice_bias is
            # the identity), but the per-rank outputs are summed across
            # ranks. Adding the bias on every rank would count it tp_size
            # times, so add it on rank 0 only. This matches the base bias,
            # which ``forward`` adds exactly once.
            if self.tp_rank == 0:
                output = apply_bias(
                    self.indices,
                    output,
                    self.bias_stacked,
                )

        self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
                                     self.lora_b_stacked, 1.0)
        return output

    def forward(self, input_):
        """Forward of RowParallelLinear

        Args:
            input_: tensor whose last dimension is `input_size`. If
                    `input_is_parallel` is set, then the last dimension
                    is `input_size // tp_size`.

        Returns:
            - output
            - bias
        """
        # Set up backprop all-reduce.
        if self.base_layer.input_is_parallel:
            input_parallel = input_
        else:
            # TODO: simplify code below
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.base_layer.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        output_parallel = self.apply(input_parallel)
        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.base_layer.skip_bias_add:
            output = (output_ + self.base_layer.bias
                      if self.base_layer.bias is not None else output_)
            output_bias = None
        else:
            output = output_
            output_bias = self.base_layer.bias
        return output, output_bias

    @property
    def weight(self):
        """Base layer weight, or ``qweight`` when there is no ``weight``."""
        return (self.base_layer.weight if hasattr(self.base_layer, "weight")
                else self.base_layer.qweight)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return type(source_layer) is RowParallelLinear

1332

1333
class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
    """
    LoRA wrapper for LogitsProcessor, with extra logic to handle the
    application of the LoRA adapter and added LoRA vocabulary.

    Args:
        base_layer: LogitsProcessor layer
        hidden_size: hidden size of the model
        dtype: data type of the model
        device: device of the model
        sharded_to_full_mapping: index mapping from sharded vocab to full vocab
            received from base_layer.get_sharded_to_full_mapping(). If None,
            no reindexing will be done.
    """

    def __init__(self, base_layer: LogitsProcessor, hidden_size: int,
                 dtype: torch.dtype, device: torch.device,
                 sharded_to_full_mapping: Optional[List[int]]) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.sharded_to_full_mapping = sharded_to_full_mapping

    # The properties below forward to the wrapped LogitsProcessor so that
    # ``forward`` (which reuses the base class's forward) sees the same
    # attributes.
    @property
    def logits_as_input(self):
        return self.base_layer.logits_as_input

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def scale(self):
        return self.base_layer.scale

    @property
    def soft_cap(self):
        return self.base_layer.soft_cap

    @property
    def use_gather(self):
        return self.base_layer.use_gather

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    @property
    def should_modify_greedy_probs_inplace(self):
        return self.base_layer.should_modify_greedy_probs_inplace

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate zeroed LoRA A/B buffers and ``-inf``-filled extra-vocab
        embedding buffers for ``max_loras`` adapter slots.

        Raises:
            ValueError: if the base vocab size exceeds 257024.
        """
        # TODO: Verify if this condition can be further relaxed
        # Only an upper bound is enforced; smaller vocabularies are accepted.
        # (The previous chained form ``32000 < v > 257024`` reduced to this.)
        if self.base_layer.vocab_size > 257024:
            raise ValueError("When using LoRA, vocab size must be "
                             "<= 257024")
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                # Pad for kernel compatibility
                math.ceil(self.base_layer.vocab_size /
                          lora_config.lora_vocab_padding_size) *
                lora_config.lora_vocab_padding_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.embeddings_tensors = torch.full(
            (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
            fill_value=float("-inf"),
            dtype=self.dtype,
            device=self.device,
        )
        if self.sharded_to_full_mapping is not None:
            self.sharded_to_full_mapping_gpu = torch.tensor(
                self.sharded_to_full_mapping,
                device=self.device,
                dtype=torch.long)
        else:
            self.sharded_to_full_mapping_gpu = None

    def reset_lora(self, index: int):
        """Clear slot ``index``: zero A/B and reset extra embeddings to -inf."""
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = float("-inf")

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Load one adapter's weights and optional extra-vocab embeddings
        into slot ``index``. ``bias`` is accepted but not used."""
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1], ] = embeddings_tensor

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: VocabParallelEmbedding,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        """Compute logits including LoRA extra-vocab logits and LoRA deltas.

        Returns None when ``tensor_model_parallel_gather`` returns None
        (presumably on non-destination ranks; confirm against that helper).
        """
        # Get the logits for the next tokens.
        logits = lm_head.linear_method.apply(lm_head, hidden_states)
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        if logits is None:
            return None

        if self.sharded_to_full_mapping_gpu is not None:
            # Reindex full logits tensor to ensure 1:1 mapping between
            # index and token_id
            # Example for:
            #   org_vocab_size = 4
            #   added_vocab_size = 2
            #   pad_to_size = 8
            #   tp_size = 2

            # indices:  [0, 1, 2,  3, 4, 5, 6,  7]
            # token_id: [0, 1, 4, -1, 2, 3, 5, -1]

            # Therefore, the mapping is expected to be:
            # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex,
            # we get:
            # indices:  [0, 1, 2, 3, 4, 5,  6,  7]
            # token_id: [0, 1, 2, 3, 4, 5, -1, -1]
            logits = logits[:, self.sharded_to_full_mapping_gpu]

        # One extra trailing row of -inf is appended after the per-adapter
        # embedding logits.
        lora_logits = torch.empty(
            self.embeddings_tensors.shape[0] + 1,
            self.embeddings_tensors.shape[1],
            hidden_states.shape[0],
            dtype=self.embeddings_tensors.dtype,
            device=self.embeddings_tensors.device,
        )
        torch.matmul(self.embeddings_tensors,
                     hidden_states.T,
                     out=lora_logits[:-1])
        lora_logits[-1] = float("-inf")
        lora_logits = lora_logits.mT
        indices_padded = self.punica_wrapper.sampler_indices_padded
        lora_logits = (lora_logits.reshape(
            lora_logits.shape[0] * lora_logits.shape[1],
            lora_logits.shape[2],
        ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"),
                                                      posinf=float("inf"),
                                                      neginf=float("-inf")))
        # Overwrite the added-vocab columns that follow the original vocab.
        logits[:,
               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
               lora_logits.shape[1]] = lora_logits

        # LogitsProcessorWithLoRA always using bgmv
        self.punica_wrapper.add_lora_logits(logits, hidden_states,
                                            self.lora_a_stacked,
                                            self.lora_b_stacked, 1.0)

        # Remove paddings in vocab (if any).
        logits = logits[:, :self.base_layer.vocab_size]
        return logits

    def forward(self, *args, **kwargs):
        # Reuse the wrapped class's forward with ``self`` bound, so it picks
        # up this wrapper's overridden ``_get_logits``.
        return type(self.base_layer).forward(self, *args, **kwargs)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # Special handling for the LogitsProcessor.
        return False
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569


class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
    """Implements RoPE-scaled embeddings with linear scaling for
    multiple LoRA adapters with a specialized kernel.

    ``create_lora_weights`` replaces the wrapped rotary embedding with a
    single ``LinearScalingRotaryEmbedding``. That layer is configured with
    every scaling factor that may be required: the base layer's own
    factor(s) plus ``lora_config.long_lora_scaling_factors``. ``forward``
    then calls it with ``offsets`` taken from
    ``punica_wrapper.long_lora_indices``.
    """

    def __init__(self, base_layer: RotaryEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer

    @property
    def scaling_factors(self):
        """Scaling factors of the wrapped rotary embedding."""
        return self.base_layer.scaling_factors

    @property
    def rotary_dim(self):
        """Rotary dimension of the wrapped rotary embedding."""
        return self.base_layer.rotary_dim

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Rebuild ``base_layer`` as a multi-factor linear-scaling RoPE.

        No LoRA weight tensors are allocated here. ``max_loras`` and
        ``model_config`` are accepted for interface compatibility only and
        are not read.

        Args:
            max_loras: Unused.
            lora_config: Source of ``long_lora_scaling_factors``. If that is
                empty or None, only the base layer's factor(s) are kept.
            model_config: Unused.
        """
        long_lora_factors = (list(lora_config.long_lora_scaling_factors)
                             if lora_config.long_lora_scaling_factors else [])

        # Preserve the scaling the base layer already applies. A
        # LinearScalingRotaryEmbedding is constructed from a list of factors
        # and exposes them as ``scaling_factors``, which the property above
        # also reads. Take all of them rather than a singular
        # ``scaling_factor`` attribute. A plain RotaryEmbedding corresponds
        # to a factor of 1.0.
        if isinstance(self.base_layer, LinearScalingRotaryEmbedding):
            existing = self.base_layer.scaling_factors
            base_factors = ([existing] if isinstance(existing, (int, float))
                            else list(existing))
        else:
            base_factors = [1.0]

        # Deduplicate and sort so the factor order is deterministic.
        scaling_factors = sorted(set(base_factors + long_lora_factors))
        self.base_layer = LinearScalingRotaryEmbedding(
            self.base_layer.head_size,
            self.base_layer.rotary_dim,
            self.base_layer.max_position_embeddings,
            self.base_layer.base,
            self.base_layer.is_neox_style,
            scaling_factors,
            self.base_layer.dtype,
        )

    def reset_lora(self, index: int):
        """No-op: this layer keeps no per-adapter state to clear."""
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """No-op: adapter tensors are not stored by this layer."""
        ...

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the rebuilt rotary embedding to ``query`` and ``key``.

        ``offsets`` is taken from ``self.punica_wrapper.long_lora_indices``.
        ``punica_wrapper`` is not assigned in this class; presumably the
        LoRA mapping code sets it before forward runs (confirm).
        """
        return self.base_layer(
            positions,
            query,
            key,
            offsets=self.punica_wrapper.long_lora_indices,
        )

    @property
    def scaling_factor_to_offset(self) -> Dict[float, int]:
        """Mapping from scaling factor to offset, delegated to base_layer."""
        return self.base_layer.scaling_factor_to_offset

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        # Exact type match on purpose: subclasses of RotaryEmbedding are not
        # accepted. Presumably this is because rebuilding them as a plain
        # LinearScalingRotaryEmbedding would discard their own behavior.
        return type(source_layer) in (LinearScalingRotaryEmbedding,
                                      RotaryEmbedding)

    def extra_repr(self) -> str:
        return self.base_layer.extra_repr()