layers.py 55.3 KB
Newer Older
1
2
3
# pylint: disable=unused-argument
import math
from dataclasses import dataclass
4
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
5
6
7
8
9
10

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

11
from vllm.adapter_commons.layers import AdapterMapping
12
from vllm.config import LoRAConfig
13
14
15
16
17
18
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce,
                              tensor_model_parallel_gather)
19
from vllm.distributed.utils import divide
20
from vllm.lora.punica import PunicaWrapper
21
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
22
                                               MergedColumnParallelLinear,
23
                                               QKVParallelLinear,
24
                                               ReplicatedLinear,
25
                                               RowParallelLinear)
26
from vllm.model_executor.layers.logits_processor import LogitsProcessor
27
28
from vllm.model_executor.layers.rotary_embedding import (
    LinearScalingRotaryEmbedding, RotaryEmbedding)
29
from vllm.model_executor.layers.vocab_parallel_embedding import (
30
    VocabParallelEmbedding)
31
32
33
34
35

# Guard for imports needed only by static type checkers. Its body is
# currently empty: no type-checking-only imports are declared here.
if TYPE_CHECKING:
    pass


36
37
38
def _get_lora_device(base_layer: nn.Module) -> torch.device:
    # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
    """Returns the device for where to place the LoRA tensors."""
Jee Li's avatar
Jee Li committed
39
    # unquantizedLinear
40
41
    if hasattr(base_layer, "weight"):
        return base_layer.weight.device
42
43
44
    # Compressed Tensor
    elif hasattr(base_layer, "weight_packed"):
        return base_layer.weight_packed.device
45
    # GPTQ/AWQ
Jee Li's avatar
Jee Li committed
46
47
48
49
50
51
52
    elif hasattr(base_layer, "qweight"):
        return base_layer.qweight.device
    # marlin
    elif hasattr(base_layer, "B"):
        return base_layer.B.device
    else:
        raise ValueError(f"Unsupported base layer: {base_layer}")
53
54


55
56
57
58
59
60
61
def _not_fully_sharded_can_replace(can_replace):
    """
    decorator which adds the condition of not using fully sharded loras
    intended to wrap can_replace_layer()
    """

    def dec(*args, **kwargs):
62
63
        decorate = kwargs.pop("decorate") if "decorate" in kwargs else True
        condition = (not kwargs["lora_config"].fully_sharded_loras
64
65
66
67
68
69
                     if decorate else True)
        return can_replace(*args, **kwargs) and condition

    return dec


70
@dataclass
class LoRAMapping(AdapterMapping):
    """``AdapterMapping`` extended with an ``is_prefill`` flag.

    ``is_prefill`` defaults to ``False``.
    NOTE(review): presumably marks prefill (vs. decode) batches -- confirm
    against the callers that build this mapping.
    """

    is_prefill: bool = False
73
74
75
76


class BaseLayerWithLoRA(nn.Module):
    """Common hook interface for layers that add LoRA to a base layer.

    In this base class every hook is a no-op returning ``None``, except
    ``set_mapping`` (stores the Punica wrapper) and ``can_replace_layer``
    (raises ``NotImplementedError``). Concrete subclasses override them.
    """

    def slice_lora_a(
        self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
        """Slice LoRA A for tensor parallelism (no-op in the base class)."""

    def slice_lora_b(
        self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
        """Slice LoRA B for tensor parallelism (no-op in the base class)."""

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate the LoRA buffers (no-op in the base class)."""

    def reset_lora(self, index: int):
        """Zero the LoRA slot ``index`` (no-op in the base class)."""

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Overwrite the LoRA slot ``index`` (no-op in the base class)."""

    def set_mapping(
        self,
        punica_wrapper: PunicaWrapper,
    ):
        """Store the Punica wrapper that subclasses use to apply LoRA ops."""
        self.punica_wrapper: PunicaWrapper = punica_wrapper

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Tell whether ``source_layer`` can be swapped for this LoRA layer.

        Subclasses must override this; the base version always raises.
        """
        raise NotImplementedError

130
131
132
133
134
135

class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper around a ``VocabParallelEmbedding`` layer.

    Each LoRA slot has:

    * an A matrix looked up by token id, with rows for the original
      vocabulary plus ``lora_extra_vocab_size`` extra tokens;
    * a B matrix;
    * an extra-vocabulary embedding table. ``set_lora`` copies it into the
      base layer's added-embedding rows when this partition has any.
    """

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        # Declared here; both are assigned in create_lora_weights().
        self.embeddings_slice: Optional[Tuple[int, int]]
        self.embeddings_weights: Optional[torch.Tensor]

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        """Allocate zeroed LoRA buffers for up to ``max_loras`` slots.

        If this partition owns added-vocabulary rows:

        * ``embeddings_weights`` becomes a view of those rows of the base
          weight;
        * ``embeddings_slice`` records their ``[start, end)`` range relative
          to ``org_vocab_size``;
        * every base-weight row past the original-vocabulary rows is zeroed.

        Otherwise both attributes are ``None``.
        """
        if self.base_layer.num_added_embeddings_per_partition > 0:
            # We can start adding lora weights
            num_org = self.base_layer.num_org_embeddings_per_partition
            num_added = self.base_layer.num_added_embeddings_per_partition
            self.embeddings_weights = self.base_layer.weight.data[
                num_org:num_org + num_added]
            shard_indices = self.base_layer.shard_indices
            org_vocab_size = self.base_layer.org_vocab_size
            self.embeddings_slice = (
                shard_indices.added_vocab_start_index - org_vocab_size,
                shard_indices.added_vocab_end_index - org_vocab_size)
            self.base_layer.weight.data[num_org:].fill_(0)
        else:
            self.embeddings_slice = None
            self.embeddings_weights = None

        self.embeddings_tensors = torch.zeros(
            (
                max_loras,
                lora_config.lora_extra_vocab_size,
                self.base_layer.embedding_dim,
            ),
            dtype=self.base_layer.weight.dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                self.base_layer.org_vocab_size +
                lora_config.lora_extra_vocab_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.embedding_dim,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        # (max_loras * vocab_rows, rank) view of lora_a_stacked, so that one
        # F.embedding lookup in forward() can gather A rows across slots.
        self.lora_a_stacked_2d = self.lora_a_stacked.view(
            self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
            self.lora_a_stacked.shape[2],
        )

    def reset_lora(self, index: int):
        """Zero the A, B and extra-embedding buffers of slot ``index``."""
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Load one LoRA into slot ``index``. ``bias`` is not used here.

        ``lora_a`` is copied as-is and ``lora_b`` is copied transposed.
        ``embeddings_tensor``, if given, fills the slot's extra-vocabulary
        table. When this partition owns added rows, the matching rows of the
        base embedding weight are refreshed as well.
        """
        self.reset_lora(index)
        self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
            lora_a, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1]].copy_(embeddings_tensor, non_blocking=True)
            if self.embeddings_slice is not None:
                # TODO(yard1): Optimize this copy, we don't need to copy
                # everything, just the modified part
                embeddings = self.embeddings_tensors.view(
                    self.embeddings_tensors.shape[0] *
                    self.embeddings_tensors.shape[1],
                    self.embeddings_tensors.shape[2],
                )[self.embeddings_slice[0]:self.embeddings_slice[1]]
                assert self.embeddings_weights is not None
                self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed the token ids ``x`` and add each token's LoRA contribution.

        ``x`` itself is left unmodified.
        """
        added_tokens_mask = x > self.base_layer.org_vocab_size - 1
        embeddings_indices = self.punica_wrapper.embeddings_indices
        # NOTE(review): row 1 presumably offsets each token into its LoRA
        # slot of the flattened lora_a_stacked_2d table; confirm in
        # PunicaWrapper.
        indices = embeddings_indices[1].view_as(x)
        full_lora_a_embeddings = F.embedding(
            x + indices,
            self.lora_a_stacked_2d,
        )
        indices = embeddings_indices[0].view_as(x)
        # Shift only added-vocabulary tokens. Out-of-place on purpose, so
        # the caller's token-id tensor is not rewritten.
        full_output = self.base_layer.forward(
            x + indices * added_tokens_mask)

        full_output_org = full_output
        if full_output.ndim == 3:
            full_output = full_output.view(
                full_output.shape[0] * full_output.shape[1], -1)
        if full_lora_a_embeddings.ndim == 3:
            full_lora_a_embeddings = full_lora_a_embeddings.view(
                full_lora_a_embeddings.shape[0] *
                full_lora_a_embeddings.shape[1],
                -1,
            )

        # Embedding layer only need expand op
        self.punica_wrapper.add_expand(full_output,
                                       full_lora_a_embeddings,
                                       self.lora_b_stacked,
                                       bias_all=None,
                                       add_input=True)
        return full_output.view_as(full_output_org)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Replace exactly ``VocabParallelEmbedding`` (not subclasses)."""
        return type(source_layer) is VocabParallelEmbedding

271

272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper around a ``ReplicatedLinear`` layer.

    No tensor-parallel slicing is applied: LoRA A, LoRA B and the optional
    LoRA bias are stored whole.
    """

    def __init__(self, base_layer: ReplicatedLinear) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.input_size = self.base_layer.input_size
        self.output_size = self.base_layer.output_size
        self.device = _get_lora_device(self.base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate zeroed stacked LoRA buffers for ``max_loras`` slots.

        Shapes:

        * ``lora_a_stacked``: ``(max_loras, 1, max_lora_rank, input_size)``;
        * ``lora_b_stacked``: ``(max_loras, 1, output_size, max_lora_rank)``;
        * ``bias_stacked``: ``(max_loras, 1, output_size)`` when
          ``lora_config.bias_enabled``, otherwise ``None``.
        """
        self.lora_config = lora_config
        lora_a_output_size = lora_config.max_lora_rank
        self.lora_a_stacked = torch.zeros(
            max_loras,
            1,
            lora_a_output_size,
            self.input_size,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            max_loras,
            1,
            self.output_size,
            lora_config.max_lora_rank,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        if lora_config.bias_enabled:
            self.bias_stacked = torch.zeros(
                max_loras,
                1,
                self.output_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
        else:
            self.bias_stacked = None

    def reset_lora(self, index: int):
        """Zero slot ``index`` of every stacked LoRA buffer."""
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        if self.lora_config.bias_enabled:
            self.bias_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Load one LoRA into slot ``index``.

        A and B are copied transposed. ``bias`` is copied as-is and requires
        ``bias_enabled``, otherwise ``bias_stacked`` is ``None``.
        """
        self.reset_lora(index)

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if bias is not None:
            # The target slot is 1-D. Copy directly: Tensor.T is a no-op on
            # 1-D tensors and deprecated for ndim != 2.
            self.bias_stacked[index, 0, :bias.shape[0]].copy_(
                bias, non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the base matmul, then add the LoRA term in place."""
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
                                     self.lora_b_stacked, self.bias_stacked,
                                     1.0)
        return output

    def forward(self, input_):
        """Forward of ReplicatedLinearWithLoRA

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output = self.apply(input_, bias)

        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Replace exactly ``ReplicatedLinear`` (not subclasses)."""
        return type(source_layer) is ReplicatedLinear


383
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
    """LoRA on top of a ``ColumnParallelLinear`` layer.

    LoRA A is not sliced (``slice_lora_a`` is the identity). LoRA B and the
    LoRA bias are sliced along the output dimension, so each
    tensor-parallel rank holds only its own shard.
    """

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__()
        # The base_layer type is ColumnParallelLinear or
        # MergedColumnParallelLinear, their weight sharding logic is
        # inconsistent when TP is greater than 1.
        self.is_merged_col_linear = type(
            base_layer) is MergedColumnParallelLinear

        self.base_layer = base_layer
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size = self.base_layer.input_size
        self.output_size = self.base_layer.output_size_per_partition
        self.device = _get_lora_device(self.base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate zeroed stacked LoRA buffers for ``max_loras`` slots.

        Shapes:

        * ``lora_a_stacked``: ``(max_loras, 1, rank_part, input_size)``.
          ``rank_part`` is ``max_lora_rank``, or ``max_lora_rank / tp_size``
          with ``fully_sharded_loras``.
        * ``lora_b_stacked``: ``(max_loras, 1, output_size, max_lora_rank)``,
          where ``output_size`` is this rank's output partition.
        * ``bias_stacked``: ``(max_loras, 1, output_size)`` when
          ``bias_enabled``, otherwise ``None``.
        """
        self.lora_config = lora_config
        self.tp_size = get_tensor_model_parallel_world_size()
        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))
        self.lora_a_stacked = torch.zeros(
            max_loras,
            1,
            lora_a_output_size_per_partition,
            self.input_size,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            max_loras,
            1,
            self.output_size,
            lora_config.max_lora_rank,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )

        if lora_config.bias_enabled:
            self.bias_stacked = torch.zeros(
                max_loras,
                1,
                self.output_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
        else:
            self.bias_stacked = None

        self.output_dim = self.lora_b_stacked.shape[2]

    def reset_lora(self, index: int):
        """Zero slot ``index`` of every stacked LoRA buffer."""
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        if self.lora_config.bias_enabled:
            self.bias_stacked[index] = 0

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """LoRA A is replicated across ranks; return it unchanged."""
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Return this rank's output-column shard of ``lora_b``."""
        # Applicable to cases where the base_layer is
        # MergedColumnParallelLinear.
        if self.is_merged_col_linear:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = self.output_size // 2
            offset = lora_b.shape[-1] // 2

            left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) *
                                 shard_size]
            right_weight = lora_b[:, offset + tp_rank * shard_size:offset +
                                  (tp_rank + 1) * shard_size]
            lora_b = torch.cat([left_weight, right_weight], dim=1)
        # Applicable to cases where the base_layer is
        # ColumnParallelLinear.
        else:
            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
            shard_size = self.output_dim
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            lora_b = lora_b[:, start_idx:end_idx]
        return lora_b

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        """Return this rank's shard of the LoRA bias (``None`` passes through).

        The bias is copied into a 1-D slot of ``bias_stacked``, so it is
        sliced along dim 0. The sharding mirrors ``slice_lora_b``:

        * ``MergedColumnParallelLinear`` base: this rank's shard is taken
          from each half, and the two pieces are concatenated.
        * Otherwise: a contiguous block of ``output_dim`` entries is taken.
        """
        if bias is None:
            return bias
        tp_rank = get_tensor_model_parallel_rank()
        if self.is_merged_col_linear:
            shard_size = self.output_size // 2
            offset = bias.shape[-1] // 2
            left_bias = bias[tp_rank * shard_size:(tp_rank + 1) * shard_size]
            right_bias = bias[offset + tp_rank * shard_size:offset +
                              (tp_rank + 1) * shard_size]
            return torch.cat([left_bias, right_bias], dim=0)
        shard_size = self.output_dim
        start_idx = tp_rank * shard_size
        end_idx = (tp_rank + 1) * shard_size
        return bias[start_idx:end_idx]

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Load one LoRA into slot ``index``, slicing it first when TP > 1.

        A and B are copied transposed; the 1-D bias is copied as-is.
        """
        self.reset_lora(index)

        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
            bias = self.slice_bias(bias)

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if bias is not None:
            # Tensor.T is a no-op on 1-D tensors and deprecated for
            # ndim != 2, so copy the bias directly.
            self.bias_stacked[index, 0, :bias.shape[0]].copy_(
                bias, non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the base matmul, then add the LoRA term in place."""
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
                                     self.lora_b_stacked, self.bias_stacked,
                                     1.0)
        return output

    def forward(self, input_):
        """Forward of ColumnParallelLinear

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output_parallel = self.apply(input_, bias)
        if self.base_layer.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Replace ``ColumnParallelLinear``, or a single-module
        ``MergedColumnParallelLinear`` (exact types only)."""
        return type(source_layer) is ColumnParallelLinear or (
            type(source_layer) is MergedColumnParallelLinear
            and len(packed_modules_list) == 1)

559
560
561
562
563
564
565
566
567
568
569
570
571
572

class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
    packed together (e.g. gate_proj + up_proj -> gate_up_proj).

    This means we have 2 LoRAs, each applied to one half of the layer.

    Both slices must have the same size.
    """

    def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate one set of zeroed LoRA buffers per slice.

        ``lora_a_stacked``, ``lora_b_stacked`` and (when ``bias_enabled``)
        ``bias_stacked`` are 2-tuples of tensors. Each B/bias slice covers
        half of this rank's output partition.

        Raises:
            ValueError: if the base layer does not have exactly two
                equally-sized output slices.
        """
        self.lora_config = lora_config
        n_slices = 2
        if not (len(self.base_layer.output_sizes) == n_slices
                and self.base_layer.output_sizes[0]
                == self.base_layer.output_sizes[1]):
            raise ValueError(
                "LoRAColumnParallelLinear2Slice requires 2 slices with "
                "the same size.")
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                self.output_size // 2,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(n_slices))
        if lora_config.bias_enabled:
            self.bias_stacked = tuple(
                torch.zeros(
                    max_loras,
                    1,
                    self.output_size // 2,
                    dtype=lora_config.lora_dtype,
                    device=self.device,
                ) for _ in range(n_slices))
        else:
            self.bias_stacked = None
        self.output_dim = self.lora_b_stacked[0].shape[2]
        self.output_slices = (self.output_dim, self.output_dim)

    def reset_lora(self, index: int):
        """Zero slot ``index`` of every per-slice LoRA buffer."""
        for slice_idx in range(2):
            self.lora_a_stacked[slice_idx][index] = 0
            self.lora_b_stacked[slice_idx][index] = 0
        if self.lora_config.bias_enabled:
            for slice_idx in range(2):
                self.bias_stacked[slice_idx][index] = 0

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """LoRA A is replicated across ranks; return it unchanged."""
        return lora_a

    def slice_lora_b(
        self, lora_b: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """Return this rank's output-column shard of each sub-LoRA B."""
        #NOTE: lora_b contains 2 subloras, and each sublora could be None.
        shard_size = self.output_dim
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        lora_b = [
            lora_b[0][:, start_idx:end_idx] if lora_b[0] is not None else None,
            lora_b[1][:, start_idx:end_idx] if lora_b[1] is not None else None,
        ]
        return lora_b

    def slice_bias(
        self, bias: List[Union[torch.Tensor,
                               None]]) -> List[Union[torch.Tensor, None]]:
        """Return this rank's shard of each 1-D sub-LoRA bias."""
        # NOTE : each bias could be None.
        shard_size = self.output_dim
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        bias = [
            bias[0][start_idx:end_idx] if bias[0] is not None else None,
            bias[1][start_idx:end_idx] if bias[1] is not None else None
        ]
        return bias

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Load a two-slice LoRA into slot ``index``, slicing it when TP > 1.

        ``lora_a``, ``lora_b`` and the optional ``bias`` hold one entry per
        slice. A ``None`` A entry leaves that slice's A/B zeroed, and a
        ``None`` bias entry leaves that slice's bias zeroed. A and B are
        copied transposed; each 1-D bias is copied as-is.
        """
        self.reset_lora(index)

        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
            if bias is not None:
                bias = self.slice_bias(bias)

        for i in range(2):
            if lora_a[i] is not None:
                self.lora_a_stacked[i][
                    index, 0, :lora_a[i].shape[1], :lora_a[i].shape[0]].copy_(
                        lora_a[i].T, non_blocking=True)
                self.lora_b_stacked[i][
                    index, 0, :lora_b[i].shape[1], :lora_b[i].shape[0]].copy_(
                        lora_b[i].T, non_blocking=True)
            if bias is not None and bias[i] is not None:
                # Tensor.T is a no-op on 1-D tensors and deprecated for
                # ndim != 2, so copy the bias directly.
                self.bias_stacked[i][index, 0, :bias[i].shape[0]].copy_(
                    bias[i], non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the base matmul, then add both slices' LoRA terms in place."""
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        self.punica_wrapper.add_lora_packed_nslice(
            output, x, self.lora_a_stacked, self.lora_b_stacked,
            self.bias_stacked, 1.0, self.output_slices)
        return output

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Replace a ``MergedColumnParallelLinear`` packing two modules."""
        return (type(source_layer) is MergedColumnParallelLinear
                and len(packed_modules_list) == 2)
723

724
725

class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """
    ColumnParallelLinear layer that is specifically designed for
    qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
    only contain a single LoRA within their qkv_proj layer.

    During inference with Tensor Parallel, the weights of lora_b
    must be accurately partitioned according to the respective ranks.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        self.tp_size = get_tensor_model_parallel_world_size()
        self.q_proj_total_size = (self.base_layer.total_num_heads *
                                  self.base_layer.head_size)
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
                                   self.base_layer.head_size)

    def _qkv_shard_ranges(self) -> List[Tuple[int, int]]:
        """Return this rank's ``[start, end)`` output ranges for Q, K and V.

        Also records ``q_shard_id`` and ``kv_shard_id`` on the instance.
        """
        tp_rank = get_tensor_model_parallel_rank()
        self.q_shard_id = tp_rank
        # K/V heads may be replicated, so several ranks can map to the same
        # K/V shard.
        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
        q_start = self.q_proj_shard_size * self.q_shard_id
        k_start = (self.q_proj_total_size +
                   self.kv_proj_shard_size * self.kv_shard_id)
        v_start = (self.q_proj_total_size + self.kv_proj_total_size +
                   self.kv_proj_shard_size * self.kv_shard_id)
        return [
            (q_start, q_start + self.q_proj_shard_size),
            (k_start, k_start + self.kv_proj_shard_size),
            (v_start, v_start + self.kv_proj_shard_size),
        ]

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Return this rank's Q, K and V columns of ``lora_b``, concatenated."""
        return torch.cat(
            [lora_b[:, start:end] for start, end in self._qkv_shard_ranges()],
            dim=1)

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        """Return this rank's Q, K and V entries of a 1-D bias, concatenated.

        The pieces are 1-D, so they are joined along dim 0.
        """
        return torch.cat(
            [bias[start:end] for start, end in self._qkv_shard_ranges()],
            dim=0)

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Load one LoRA into slot ``index``, slicing it first when TP > 1.

        A and B are copied transposed; the 1-D bias is copied as-is.
        """
        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
            if bias is not None:
                bias = self.slice_bias(bias)

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if bias is not None:
            # Tensor.T is a no-op on 1-D tensors and deprecated for
            # ndim != 2, so copy the bias directly.
            self.bias_stacked[index, 0, :bias.shape[0]].copy_(
                bias, non_blocking=True)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Replace a ``QKVParallelLinear`` that carries a single LoRA."""
        return type(source_layer) is QKVParallelLinear and len(
            packed_modules_list) == 1


class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 3 sublayers (slices)
    packed together in qkv proj fashion
    (q_proj + k_proj + v_proj -> qkv_proj).

    This means we have 3 LoRAs, each applied to one slice of the layer.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate zeroed per-slice LoRA buffers for ``max_loras`` slots.

        ``lora_a_stacked``, ``lora_b_stacked`` and, when
        ``lora_config.bias_enabled``, ``bias_stacked`` become 3-tuples of
        tensors (one each for q, k and v) whose leading dims are
        ``(max_loras, 1)``.  ``bias_stacked`` is None when bias is disabled.
        """
        self.lora_config = lora_config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.q_shard_id = self.tp_rank
        # Integer division: each group of `num_kv_head_replicas` consecutive
        # ranks shares one K/V shard.
        self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas

        # Per-rank output width of the q, k and v slices, in that order.
        slice_sizes = (
            self.q_proj_shard_size,
            self.kv_proj_shard_size,
            self.kv_proj_shard_size,
        )

        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))

        # q, k, v: A buffers all share one shape, B/bias buffers follow the
        # width of their slice.
        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in slice_sizes)
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                slice_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for slice_size in slice_sizes)
        if lora_config.bias_enabled:
            self.bias_stacked = tuple(
                torch.zeros(
                    max_loras,
                    1,
                    slice_size,
                    dtype=lora_config.lora_dtype,
                    device=self.device,
                ) for slice_size in slice_sizes)
        else:
            self.bias_stacked = None

        self.output_slices = slice_sizes

        self.packed_indices: Optional[torch.Tensor] = None
        self.standard_indices: Optional[torch.Tensor] = None
        # lazily initialized.
        self.indices: torch.Tensor
        self.indices_len: List[int]

    def reset_lora(self, index: int):
        """Zero slot ``index`` in every q/k/v LoRA buffer (and bias buffer
        when bias is enabled)."""
        for lora_a, lora_b in zip(self.lora_a_stacked, self.lora_b_stacked):
            lora_a[index] = 0
            lora_b[index] = 0
        if self.lora_config.bias_enabled:
            for bias in self.bias_stacked:
                bias[index] = 0

    def _qkv_shard_ranges(self) -> List[Tuple[int, int]]:
        """Return this rank's ``[start, end)`` output range for q, k and v."""
        q_start = self.q_proj_shard_size * self.q_shard_id
        kv_start = self.kv_proj_shard_size * self.kv_shard_id
        kv_range = (kv_start, kv_start + self.kv_proj_shard_size)
        return [(q_start, q_start + self.q_proj_shard_size), kv_range,
                kv_range]

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """LoRA A is not sharded by this layer; return it unchanged."""
        return lora_a

    def slice_lora_b(
        self, lora_b: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """Keep only this rank's output columns of each q/k/v B matrix.

        ``None`` entries (slices with no LoRA) are passed through.
        """
        return [
            None if lora_b[i] is None else lora_b[i][:, start:end]
            for i, (start, end) in enumerate(self._qkv_shard_ranges())
        ]

    def slice_bias(
        self, bias: List[Union[torch.Tensor,
                               None]]) -> List[Union[torch.Tensor, None]]:
        """Keep only this rank's range of each q/k/v bias vector.

        ``bias`` must hold exactly three entries; ``None`` entries are
        passed through.
        """
        bias_q, bias_k, bias_v = bias
        return [
            None if b is None else b[start:end]
            for b, (start, end) in zip((bias_q, bias_k, bias_v),
                                       self._qkv_shard_ranges())
        ]

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Load one adapter's q/k/v LoRA weights into slot ``index``.

        ``lora_a``, ``lora_b`` and ``bias`` are indexed per slice
        (``[q, k, v]``); a ``None`` entry leaves that slice zeroed.  With
        tensor parallelism the inputs are first sliced to this rank's shard.
        ``embeddings_tensor`` is not used by this layer.
        """
        self.reset_lora(index)

        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
            if bias is not None:
                bias = self.slice_bias(bias)

        num_slices = len(self.lora_b_stacked)
        for i in range(num_slices):
            if lora_b[i] is not None:
                self.lora_b_stacked[i][
                    index, 0, :lora_b[i].shape[1], :lora_b[i].shape[0]].copy_(
                        lora_b[i].T, non_blocking=True)
        for i in range(num_slices):
            if lora_a[i] is not None:
                self.lora_a_stacked[i][
                    index, 0, :lora_a[i].shape[1], :lora_a[i].shape[0]].copy_(
                        lora_a[i].T, non_blocking=True)
        if bias is not None:
            for i in range(num_slices):
                if bias[i] is not None:
                    # bias_stacked[i][index, 0] is 1-D, so copy the bias as
                    # is: `.T` on a 1-D tensor is a no-op that PyTorch
                    # deprecates for non-2-D tensors.
                    self.bias_stacked[i][index, 0, :bias[i].shape[0]].copy_(
                        bias[i], non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the base layer's quant method, then hand its output to the
        punica wrapper to add the per-slice q/k/v LoRA contributions
        (split by ``output_slices``)."""
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        self.punica_wrapper.add_lora_packed_nslice(output, x,
                                                   self.lora_a_stacked,
                                                   self.lora_b_stacked,
                                                   self.bias_stacked, 1.0,
                                                   self.output_slices)
        return output

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Accept exactly a ``QKVParallelLinear`` packing three modules."""
        return (type(source_layer) is QKVParallelLinear
                and len(packed_modules_list) == 3)
1074

1075
1076
1077
1078
1079
1080

class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper around ``RowParallelLinear``.

    The base layer's input dimension is partitioned across tensor-parallel
    ranks, so ``slice_lora_a`` keeps only this rank's rows of LoRA A, while
    ``slice_lora_b`` and ``slice_bias`` return LoRA B and the bias unchanged.
    """

    def __init__(self, base_layer: RowParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer
        # Input width is this rank's partition; output width is not split.
        self.input_size = self.base_layer.input_size_per_partition
        self.output_size = self.base_layer.output_size
        self.device = _get_lora_device(self.base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate zeroed LoRA A/B (and optional bias) buffers with
        ``max_loras`` slots; ``bias_stacked`` is None when bias is off."""
        self.lora_config = lora_config
        self.tp_rank = get_tensor_model_parallel_rank()
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.input_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        tp_size = get_tensor_model_parallel_world_size()
        # With fully sharded LoRAs, B holds only this rank's share of the
        # output rows.
        lora_b_output_size_per_partition = (
            self.output_size if not lora_config.fully_sharded_loras else
            divide(self.output_size, tp_size))

        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_b_output_size_per_partition,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )

        if lora_config.bias_enabled:
            self.bias_stacked = torch.zeros(
                (
                    max_loras,
                    1,
                    self.output_size,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
        else:
            self.bias_stacked = None
        # Lazily initialized
        self.indices: torch.Tensor
        self.indices_len: List[int]

    def reset_lora(self, index: int):
        """Zero slot ``index`` of the A/B buffers (and bias, if enabled)."""
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        if self.lora_config.bias_enabled:
            self.bias_stacked[index] = 0

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """Keep the rows of LoRA A that match this rank's input shard."""
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        shard_size = self.input_size
        start_idx = tensor_model_parallel_rank * shard_size
        end_idx = (tensor_model_parallel_rank + 1) * shard_size
        lora_a = lora_a[start_idx:end_idx, :]
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """LoRA B is not sharded by this layer; return it unchanged."""
        return lora_b

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        """The bias is not sharded by this layer; return it unchanged."""
        return bias

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Load one adapter's LoRA weights into slot ``index``.

        The slot is cleared first; with ``base_layer.tp_size > 1`` the
        inputs are sliced to this rank's shard.  A and B are copied
        transposed; ``embeddings_tensor`` is not used by this layer.
        """
        self.reset_lora(index)

        if self.base_layer.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
            if bias is not None:
                bias = self.slice_bias(bias)

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if bias is not None:
            # bias_stacked[index, 0] is 1-D, so copy the bias as is: `.T` on
            # a 1-D tensor is a no-op that PyTorch deprecates for non-2-D
            # tensors.
            self.bias_stacked[index, 0, :bias.shape[0]].copy_(
                bias, non_blocking=True)

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        """Run the base layer's quant method (without its bias), then hand
        the output to the punica wrapper to add the LoRA contribution."""
        output = self.base_layer.quant_method.apply(self.base_layer, x)
        self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
                                     self.lora_b_stacked, self.bias_stacked,
                                     1.0)
        return output

    def forward(self, input_):
        """Forward of RowParallelLinear

        Args:
            input_: tensor whose last dimension is `input_size`. If
                    `input_is_parallel` is set, then the last dimension
                    is `input_size // tp_size`.

        Returns:
            - output
            - bias
        """
        # Obtain this rank's slice of the input, splitting it here when the
        # caller passed the full (unpartitioned) tensor.
        if self.base_layer.input_is_parallel:
            input_parallel = input_
        else:
            # TODO: simplify code below
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.base_layer.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        output_parallel = self.apply(input_parallel)
        # Sum the per-rank partial outputs across tensor-parallel ranks.
        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.base_layer.skip_bias_add:
            output = (output_ + self.base_layer.bias
                      if self.base_layer.bias is not None else output_)
            output_bias = None
        else:
            output = output_
            output_bias = self.base_layer.bias
        return output, output_bias

    @property
    def weight(self):
        """Base weight: ``weight`` if present, otherwise ``qweight``."""
        return (self.base_layer.weight if hasattr(self.base_layer, "weight")
                else self.base_layer.qweight)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Accept exactly a ``RowParallelLinear`` (subclasses rejected)."""
        return type(source_layer) is RowParallelLinear

1243

1244
class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
    """
    LoRA wrapper for LogitsProcessor, with extra logic to handle the
    application of the LoRA adapter and added LoRA vocabulary.

    Args:
        base_layer: LogitsProcessor layer
        hidden_size: hidden size of the model
        dtype: data type of the model
        device: device of the model
        sharded_to_full_mapping: index mapping from sharded vocab to full vocab
            received from base_layer.get_sharded_to_full_mapping(). If None,
            no reindexing will be done.
    """

    def __init__(self, base_layer: LogitsProcessor, hidden_size: int,
                 dtype: torch.dtype, device: torch.device,
                 sharded_to_full_mapping: Optional[List[int]]) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.sharded_to_full_mapping = sharded_to_full_mapping

    # The properties below are read-only passthroughs to the base layer.

    @property
    def logits_as_input(self):
        return self.base_layer.logits_as_input

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def scale(self):
        return self.base_layer.scale

    @property
    def soft_cap(self):
        return self.base_layer.soft_cap

    @property
    def use_gather(self):
        return self.base_layer.use_gather

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    @property
    def should_modify_greedy_probs_inplace(self):
        return self.base_layer.should_modify_greedy_probs_inplace

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate LoRA A/B buffers and the added-vocabulary embeddings.

        Raises:
            ValueError: if the base vocabulary exceeds 257024 entries.
        """
        # TODO: Verify if this condition can be further relaxed
        # Only an upper bound is enforced; vocabularies below 32000 are
        # accepted.
        max_vocab_size = 257024
        if self.base_layer.vocab_size > max_vocab_size:
            raise ValueError("When using LoRA, vocab size must be "
                             f"<= {max_vocab_size}, got "
                             f"{self.base_layer.vocab_size}")
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                # Pad for kernel compatibility
                math.ceil(self.base_layer.vocab_size /
                          lora_config.lora_vocab_padding_size) *
                lora_config.lora_vocab_padding_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        # -inf marks unused added-vocab rows (reset_lora restores it).
        self.embeddings_tensors = torch.full(
            (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
            fill_value=float("-inf"),
            dtype=self.dtype,
            device=self.device,
        )
        if self.sharded_to_full_mapping is not None:
            self.sharded_to_full_mapping_gpu = torch.tensor(
                self.sharded_to_full_mapping,
                device=self.device,
                dtype=torch.long)
        else:
            self.sharded_to_full_mapping_gpu = None

    def reset_lora(self, index: int):
        """Clear slot ``index``: zero A/B and refill embeddings with -inf."""
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = float("-inf")

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Load one adapter into slot ``index``.

        A and B are copied transposed.  ``embeddings_tensor``, if given, is
        written into the top-left corner of the slot's added-vocab
        embeddings.  ``bias`` is not used by this layer.
        """
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1], ] = embeddings_tensor

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: VocabParallelEmbedding,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        """Compute next-token logits with LoRA applied.

        Returns None when ``tensor_model_parallel_gather`` returns None
        (presumably on ranks that do not receive the gathered logits).
        Otherwise the columns starting at ``org_vocab_size`` are overwritten
        with scores against the added-vocab embeddings (rows picked through
        ``punica_wrapper.sampler_indices_padded``), the LoRA A/B delta is
        added, and padding past ``vocab_size`` is trimmed.
        """
        # Get the logits for the next tokens.
        logits = lm_head.linear_method.apply(lm_head, hidden_states)
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        if logits is None:
            return None

        if self.sharded_to_full_mapping_gpu is not None:
            # Reindex full logits tensor to ensure 1:1 mapping between
            # index and token_id
            # Example for:
            #   org_vocab_size = 4
            #   added_vocab_size = 2
            #   pad_to_size = 8
            #   tp_size = 2

            # indices:  [0, 1, 2,  3, 4, 5, 6,  7]
            # token_id: [0, 1, 4, -1, 2, 3, 5, -1]

            # Therefore, the mapping is expected to be:
            # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex,
            # we get:
            # indices:  [0, 1, 2, 3, 4, 5,  6,  7]
            # token_id: [0, 1, 2, 3, 4, 5, -1, -1]
            logits = logits[:, self.sharded_to_full_mapping_gpu]

        # One extra (last) slot filled with -inf, selectable through the
        # padded sampler indices.
        lora_logits = torch.empty(
            self.embeddings_tensors.shape[0] + 1,
            self.embeddings_tensors.shape[1],
            hidden_states.shape[0],
            dtype=self.embeddings_tensors.dtype,
            device=self.embeddings_tensors.device,
        )
        torch.matmul(self.embeddings_tensors,
                     hidden_states.T,
                     out=lora_logits[:-1])
        lora_logits[-1] = float("-inf")
        lora_logits = lora_logits.mT
        indices_padded = self.punica_wrapper.sampler_indices_padded
        lora_logits = (lora_logits.reshape(
            lora_logits.shape[0] * lora_logits.shape[1],
            lora_logits.shape[2],
        ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"),
                                                      posinf=float("inf"),
                                                      neginf=float("-inf")))
        logits[:,
               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
               lora_logits.shape[1]] = lora_logits

        # LogitsProcessorWithLoRA always using bgmv
        self.punica_wrapper.add_lora_logits(logits, hidden_states,
                                            self.lora_a_stacked,
                                            self.lora_b_stacked, 1.0)

        # Remove paddings in vocab (if any).
        logits = logits[:, :self.base_layer.vocab_size]
        return logits

    def forward(self, *args, **kwargs):
        # Run the base class's forward with `self` bound to this wrapper so
        # that calls it makes to `self._get_logits` reach the override above.
        return type(self.base_layer).forward(self, *args, **kwargs)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # Special handling for the LogitsProcessor.
        return False
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480


class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
    """RoPE with linear scaling shared by multiple LoRA adapters.

    ``create_lora_weights`` replaces the wrapped rotary embedding with a
    ``LinearScalingRotaryEmbedding`` built from every scaling factor in use:
    the base layer's own factor (1.0 for a plain ``RotaryEmbedding``) plus
    ``lora_config.long_lora_scaling_factors``.  ``forward`` passes
    ``punica_wrapper.long_lora_indices`` as ``offsets`` to that layer.
    """

    def __init__(self, base_layer: RotaryEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer

    @property
    def scaling_factors(self):
        return self.base_layer.scaling_factors

    @property
    def rotary_dim(self):
        return self.base_layer.rotary_dim

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Rebuild the base layer with the sorted, de-duplicated union of
        the base scaling factor and the configured long-LoRA factors."""
        base_scaling_factor = (self.base_layer.scaling_factor if isinstance(
            self.base_layer, LinearScalingRotaryEmbedding) else 1.0)
        extra_factors = lora_config.long_lora_scaling_factors or ()
        scaling_factors = sorted({base_scaling_factor, *extra_factors})
        self.base_layer = LinearScalingRotaryEmbedding(
            self.base_layer.head_size,
            self.base_layer.rotary_dim,
            self.base_layer.max_position_embeddings,
            self.base_layer.base,
            self.base_layer.is_neox_style,
            scaling_factors,
            self.base_layer.dtype,
        )

    def reset_lora(self, index: int):
        """No per-adapter weights are stored; nothing to reset."""
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """No per-adapter weights are stored; nothing to load."""
        ...

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the scaled rotary embedding, selecting scaling offsets via
        the punica wrapper's ``long_lora_indices``."""
        return self.base_layer(
            positions,
            query,
            key,
            offsets=self.punica_wrapper.long_lora_indices,
        )

    @property
    def scaling_factor_to_offset(self) -> Dict[float, int]:
        return self.base_layer.scaling_factor_to_offset

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        # Exact type match: subclasses of either embedding are rejected.
        return type(source_layer) in (LinearScalingRotaryEmbedding,
                                      RotaryEmbedding)

    def extra_repr(self) -> str:
        return self.base_layer.extra_repr()