layers.py 50.9 KB
Newer Older
1
2
3
# pylint: disable=unused-argument
import math
from dataclasses import dataclass
4
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
5
6
7
8
9
10

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

11
from vllm.adapter_commons.layers import AdapterMapping
12
from vllm.config import LoRAConfig
13
14
15
16
17
18
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce,
                              tensor_model_parallel_gather)
19
from vllm.distributed.utils import divide
20
from vllm.lora.punica import PunicaWrapper
21
# yapf: disable
22
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
23
                                               LinearBase,
24
                                               MergedColumnParallelLinear,
25
                                               QKVParallelLinear,
26
                                               ReplicatedLinear,
27
                                               RowParallelLinear)
28
# yapf: enable
29
from vllm.model_executor.layers.logits_processor import LogitsProcessor
30
31
from vllm.model_executor.layers.rotary_embedding import (
    LinearScalingRotaryEmbedding, RotaryEmbedding)
32
from vllm.model_executor.layers.vocab_parallel_embedding import (
33
    VocabParallelEmbedding)
34
35
36
37
38

if TYPE_CHECKING:
    pass


39
40
41
def _get_lora_device(base_layer: nn.Module) -> torch.device:
    """Return the device on which LoRA tensors for ``base_layer`` are placed.

    The device is taken from the first weight attribute the (possibly
    quantized) base layer exposes, probed in this order:

    * ``weight``        - unquantized linear layers
    * ``weight_packed`` - compressed-tensors layers
    * ``qweight``       - GPTQ / AWQ layers
    * ``B``             - Marlin layers

    Raises:
        ValueError: if ``base_layer`` has none of these attributes.
    """
    # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
    for attr_name in ("weight", "weight_packed", "qweight", "B"):
        if hasattr(base_layer, attr_name):
            return getattr(base_layer, attr_name).device
    raise ValueError(f"Unsupported base layer: {base_layer}")
56
57


58
59
60
61
62
63
64
def _not_fully_sharded_can_replace(can_replace):
    """
    decorator which adds the condition of not using fully sharded loras
    intended to wrap can_replace_layer()

    The wrapper reads ``kwargs["lora_config"]`` (so ``lora_config`` must be
    passed by keyword) and consumes an optional ``decorate`` keyword
    (default ``True``). When ``decorate`` is false, the
    ``fully_sharded_loras`` check is skipped and only ``can_replace`` is
    consulted. ``decorate`` is never forwarded to ``can_replace``.
    """
    from functools import wraps

    @wraps(can_replace)
    def dec(*args, **kwargs):
        decorate = kwargs.pop("decorate", True)
        # The condition is evaluated before calling the wrapped predicate,
        # preserving the original evaluation order.
        condition = (not kwargs["lora_config"].fully_sharded_loras
                     if decorate else True)
        return can_replace(*args, **kwargs) and condition

    return dec


73
@dataclass
class LoRAMapping(AdapterMapping):
    """:class:`AdapterMapping` for LoRA with an extra ``is_prefill`` flag."""

    is_prefill: bool = False
76
77
78
79


class BaseLayerWithLoRA(nn.Module):
    """Common interface for modules that wrap a base layer with LoRA.

    Concrete subclasses override these hooks to allocate the stacked LoRA
    tensors, write or zero the adapter stored at a slot index, slice
    adapter weights for tensor parallelism, and decide whether a given
    module can be wrapped.
    """

    def slice_lora_a(
        self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
        """Slice lora a if splitting for tensor parallelism."""
        ...

    def slice_lora_b(
        self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
        """Slice lora b if splitting with tensor parallelism."""
        ...

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Initializes lora matrices.

        Args:
            max_loras: number of adapter slots to allocate.
            lora_config: LoRA settings (rank, dtype, sharding mode, ...).
            model_config: optional model configuration.
        """
        ...

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Overwrites lora tensors at index."""
        ...

    def set_mapping(
        self,
        punica_wrapper: PunicaWrapper,
    ):
        """Store the Punica wrapper used to apply LoRA in ``forward``."""
        self.punica_wrapper: PunicaWrapper = punica_wrapper

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        raise NotImplementedError

133
134
135
136
137
138

class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper around a :class:`VocabParallelEmbedding`.

    Besides the A/B matrices, each adapter may supply embeddings for up to
    ``lora_config.lora_extra_vocab_size`` extra tokens. These are kept in
    ``embeddings_tensors``. When this partition owns added-vocabulary rows,
    they are also copied into the base layer's weight.
    """

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        # Assigned in create_lora_weights(). Both are None when this
        # partition holds no added-vocabulary rows.
        self.embeddings_slice: Optional[Tuple[int, int]]
        self.embeddings_weights: Optional[torch.Tensor]

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        """Allocate per-adapter LoRA buffers on the base weight's device.

        Allocates:

        * ``embeddings_tensors``: ``(max_loras, lora_extra_vocab_size,
          embedding_dim)`` in the base weight's dtype.
        * ``lora_a_stacked``: ``(max_loras, org_vocab_size +
          lora_extra_vocab_size, max_lora_rank)`` in ``lora_dtype``.
        * ``lora_b_stacked``: ``(max_loras, 1, embedding_dim,
          max_lora_rank)`` in ``lora_dtype``.
        * ``lora_a_stacked_2d``: a view of ``lora_a_stacked`` with its
          first two dimensions flattened.

        If ``num_added_embeddings_per_partition > 0``:

        * ``embeddings_weights`` becomes a view of the base weight rows
          ``[num_org_embeddings_per_partition,
          num_org_embeddings_per_partition + num_added_embeddings_per_partition)``.
        * ``embeddings_slice`` records this shard's added-vocab range,
          measured relative to ``org_vocab_size``.
        * All base weight rows from ``num_org_embeddings_per_partition``
          onward are zeroed.
        """
        base = self.base_layer
        if base.num_added_embeddings_per_partition > 0:
            # We can start adding lora weights
            org_rows = base.num_org_embeddings_per_partition
            # A view: writes into embeddings_weights land in the base weight.
            self.embeddings_weights = base.weight.data[
                org_rows:org_rows + base.num_added_embeddings_per_partition]
            self.embeddings_slice = (
                base.shard_indices.added_vocab_start_index -
                base.org_vocab_size,
                base.shard_indices.added_vocab_end_index -
                base.org_vocab_size)
            base.weight.data[org_rows:].fill_(0)
        else:
            self.embeddings_slice = None
            self.embeddings_weights = None

        self.embeddings_tensors = torch.zeros(
            (
                max_loras,
                lora_config.lora_extra_vocab_size,
                base.embedding_dim,
            ),
            dtype=base.weight.dtype,
            device=base.weight.device,
        )
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                base.org_vocab_size + lora_config.lora_extra_vocab_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=base.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                base.embedding_dim,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=base.weight.device,
        )
        self.lora_a_stacked_2d = self.lora_a_stacked.view(
            self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
            self.lora_a_stacked.shape[2],
        )

    def reset_lora(self, index: int):
        """Zero the A, B and extra-embedding buffers of slot ``index``."""
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Load an adapter into slot ``index``. ``bias`` is ignored.

        The slot is zeroed first. ``lora_a`` is copied as-is and ``lora_b``
        is copied transposed. A non-None ``embeddings_tensor`` is copied
        into ``embeddings_tensors``. If this partition owns added-vocab
        rows, the corresponding range is also written into the base weight.
        """
        self.reset_lora(index)
        self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
            lora_a, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[index, :embeddings_tensor.shape[0],
                                    :embeddings_tensor.shape[1]].copy_(
                                        embeddings_tensor, non_blocking=True)
            if self.embeddings_slice is not None:
                # TODO(yard1): Optimize this copy, we don't need to copy
                # everything, just the modified part
                embeddings = self.embeddings_tensors.view(
                    self.embeddings_tensors.shape[0] *
                    self.embeddings_tensors.shape[1],
                    self.embeddings_tensors.shape[2],
                )[self.embeddings_slice[0]:self.embeddings_slice[1]]
                assert self.embeddings_weights is not None
                self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed token ids ``x`` and add the per-token LoRA contribution.

        NOTE(review): ``x`` is modified in place (``x.add_``) before the
        base-layer lookup, so callers must not reuse it afterwards.
        """
        # True for ids at or beyond the original vocabulary size.
        added_tokens_mask = x > self.base_layer.org_vocab_size - 1
        embeddings_indices = self.punica_wrapper.embeddings_indices
        # Offsets added to the ids to index the flattened lora_a table
        # (presumably selecting each token's adapter block; see
        # PunicaWrapper).
        indices = embeddings_indices[1].view_as(x)
        full_lora_a_embeddings = F.embedding(
            x + indices,
            self.lora_a_stacked_2d,
        )
        # These offsets are applied only to added-vocab tokens, in place,
        # before the base-layer lookup.
        indices = embeddings_indices[0].view_as(x)
        full_output = self.base_layer.forward(
            x.add_(indices * added_tokens_mask))

        full_output_org = full_output
        # Flatten any (batch, seq) leading dims to 2-D for the punica kernel.
        if full_output.ndim == 3:
            full_output = full_output.view(
                full_output.shape[0] * full_output.shape[1], -1)
        if full_lora_a_embeddings.ndim == 3:
            full_lora_a_embeddings = full_lora_a_embeddings.view(
                full_lora_a_embeddings.shape[0] *
                full_lora_a_embeddings.shape[1],
                -1,
            )
        self.punica_wrapper.add_lora_embedding(full_output,
                                               full_lora_a_embeddings,
                                               self.lora_b_stacked,
                                               add_input=True)
        return full_output.view_as(full_output_org)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Only an exact ``VocabParallelEmbedding`` (not a subclass)."""
        return type(source_layer) is VocabParallelEmbedding

271

272
class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
    """Shared LoRA implementation for linear layers.

    Subclasses must set ``tp_size``, ``output_size`` and ``n_slices``
    before ``create_lora_weights`` runs. This class allocates ``n_slices``
    stacked A/B buffers (plus bias buffers when ``bias_enabled``). It
    applies them on top of the base layer's ``quant_method`` output.
    """

    def __init__(self, base_layer: LinearBase):
        super().__init__()
        self.base_layer = base_layer
        self.input_size = self.base_layer.input_size
        self.device = _get_lora_device(self.base_layer)
        # Stays None unless lora_config.bias_enabled.
        self.lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]] = None

        # Declared here; assigned by subclasses or create_lora_weights().
        self.output_slices: Tuple[int, ...]
        self.tp_size: int
        self.output_size: int
        self.n_slices: int

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate ``n_slices`` zeroed LoRA buffers per weight kind.

        Shapes per slice:

        * A: ``(max_loras, 1, lora_a_out_size, input_size)``
        * B: ``(max_loras, 1, lora_b_out_size, max_lora_rank)``
        * bias (if enabled): ``(max_loras, 1, lora_b_out_size)``

        Raises:
            NotImplementedError: for base layers that are not replicated,
                column-parallel or row-parallel.
        """
        self.lora_config = lora_config
        # The output sizes depend on how the base layer is parallelized.
        # With fully sharded LoRAs, column-parallel layers shard lora_a's
        # rank dimension and row-parallel layers shard lora_b's outputs.
        if isinstance(self.base_layer, ReplicatedLinear):
            lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = self.output_size

        elif isinstance(self.base_layer, ColumnParallelLinear):
            lora_a_out_size = (lora_config.max_lora_rank if
                               not lora_config.fully_sharded_loras else divide(
                                   lora_config.max_lora_rank, self.tp_size))
            lora_b_out_size = self.output_size

        elif isinstance(self.base_layer, RowParallelLinear):
            lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = (self.output_size if
                               not lora_config.fully_sharded_loras else divide(
                                   self.output_size, self.tp_size))
        else:
            raise NotImplementedError

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_out_size,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(self.n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_b_out_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(self.n_slices))
        if lora_config.bias_enabled:
            lora_bias_out_size = lora_b_out_size
            self.lora_bias_stacked = tuple(
                torch.zeros(
                    max_loras,
                    1,
                    lora_bias_out_size,
                    dtype=lora_config.lora_dtype,
                    device=self.device,
                ) for _ in range(self.n_slices))
        self.output_slices = (self.lora_b_stacked[0].shape[2], )

    def reset_lora(self, index: int):
        """Zero every slice's A, B (and bias, if enabled) at ``index``."""
        for s_index in range(self.n_slices):
            self.lora_a_stacked[s_index][index] = 0
            self.lora_b_stacked[s_index][index] = 0
            if self.lora_config.bias_enabled:
                # Make mypy happy
                self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
                                              self.lora_bias_stacked)
                self.lora_bias_stacked[s_index][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        lora_bias: Optional[torch.Tensor] = None,
    ):
        """Load a single-slice adapter into slot ``index``.

        With ``tp_size > 1``, the weights (and bias) are first sliced for
        this rank via the subclass's ``slice_*`` hooks. A and B are stored
        transposed. ``embeddings_tensor`` is unused here.
        """
        # Except for QKVParallelLinearWithLora and
        # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
        # store weights in a tuple of size 1. These two layers will
        # override this function.
        assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) ==
                self.n_slices == 1)

        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
            if lora_bias is not None:
                lora_bias = self.slice_bias(lora_bias)

        self.lora_a_stacked[0][index,
                               0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                   lora_a.T, non_blocking=True)
        self.lora_b_stacked[0][index,
                               0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                   lora_b.T, non_blocking=True)
        if lora_bias is not None:
            self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
                                          self.lora_bias_stacked)
            assert len(self.lora_bias_stacked)
            # The bias is 1-D, so it is copied directly. `.T` would be a
            # no-op here and is deprecated for non-2-D tensors.
            self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_(
                lora_bias, non_blocking=True)

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the base layer's quant method, then add LoRA in place."""
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked,
                                            self.lora_b_stacked,
                                            self.lora_bias_stacked, 1.0,
                                            self.output_slices)
        return output

398
399
400
401
402
403
404
405
406
407

class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
    """LoRA wrapper for a :class:`ReplicatedLinear` (no tensor sharding)."""

    def __init__(self, base_layer: ReplicatedLinear) -> None:
        super().__init__(base_layer)
        # To ensure interface compatibility, set to 1 always.
        self.tp_size = 1
        self.output_size = self.base_layer.output_size
        self.n_slices = 1

    def forward(self, input_):
        """Forward of ReplicatedLinearWithLoRA

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias: the base layer's bias when ``skip_bias_add`` is set
              (so the caller adds it), otherwise None.
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output = self.apply(input_, bias)

        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Only an exact ``ReplicatedLinear`` (not a subclass)."""
        return type(source_layer) is ReplicatedLinear


440
class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
    """
    LoRA on top of ColumnParallelLinear layer.
    LoRA B is sliced for tensor parallelism.

    There are two types for the `base_layer`:
    1. ColumnParallelLinear, e.g.`dense_h_to_4h` in `FalconForCausalLM`.
    2. MergedColumnParallelLinear, e.g.`gate_up_proj` in `Phi3ForCausalLM`.
    """

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__(base_layer)
        # The base_layer type is ColumnParallelLinear or
        # MergedColumnParallelLinear, their weight sharding logic is
        # inconsistent when TP is greater than 1.
        self.is_merged_col_linear = type(
            base_layer) is MergedColumnParallelLinear
        self.tp_size = get_tensor_model_parallel_world_size()
        # Per-partition output width; also lora_b's output dim.
        self.output_size = self.base_layer.output_size_per_partition
        # There is only one LoRA layer
        self.n_slices = 1

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """lora_a is replicated across ranks, so it is returned unchanged."""
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Return this rank's columns of ``lora_b``.

        For a MergedColumnParallelLinear base layer, lora_b's columns are
        treated as two equal halves. This rank's shard is taken from each
        half and the two are concatenated. Otherwise one contiguous shard
        of ``output_size`` columns is taken.
        """
        tp_rank = get_tensor_model_parallel_rank()
        if self.is_merged_col_linear:
            shard_size = self.output_size // 2
            offset = lora_b.shape[-1] // 2
            left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) *
                                 shard_size]
            right_weight = lora_b[:, offset + tp_rank * shard_size:offset +
                                  (tp_rank + 1) * shard_size]
            return torch.cat([left_weight, right_weight], dim=1)
        # Plain ColumnParallelLinear. `output_size` is the per-partition
        # width; `output_dim` is never set on this class.
        shard_size = self.output_size
        return lora_b[:, tp_rank * shard_size:(tp_rank + 1) * shard_size]

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        """Return this rank's shard of the 1-D LoRA ``bias`` (None passes).

        Mirrors :meth:`slice_lora_b`: two halves for a merged base layer,
        otherwise one contiguous ``output_size`` shard.
        """
        if bias is None:
            return bias
        tp_rank = get_tensor_model_parallel_rank()
        if self.is_merged_col_linear:
            shard_size = self.output_size // 2
            offset = bias.shape[-1] // 2
            left_bias = bias[tp_rank * shard_size:(tp_rank + 1) * shard_size]
            right_bias = bias[offset + tp_rank * shard_size:offset +
                              (tp_rank + 1) * shard_size]
            return torch.cat([left_bias, right_bias], dim=0)
        shard_size = self.output_size
        return bias[tp_rank * shard_size:(tp_rank + 1) * shard_size]

    def forward(self, input_):
        """Forward of ColumnParallelLinear

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output: gathered across ranks when the base layer's
              ``gather_output`` is set, otherwise this rank's partition.
            - bias: the base layer's bias when ``skip_bias_add`` is set,
              otherwise None.
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output_parallel = self.apply(input_, bias)
        if self.base_layer.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """An exact ColumnParallelLinear, or a MergedColumnParallelLinear
        carrying a single packed module."""
        return type(source_layer) is ColumnParallelLinear or (
            type(source_layer) is MergedColumnParallelLinear
            and len(packed_modules_list) == 1)

535
536
537
538
539
540
541
542
543
544
545
546

class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
    packed together (eg. gate_proj + up_proj -> gate_up_proj).

    This means we have 2 LoRAs, each applied to one half of the layer.

    Both slices must have the same size.
    """

    def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
        super().__init__(base_layer)
        # There are two LoRA layers
        self.n_slices = len(self.base_layer.output_sizes)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate two LoRA slices, each covering half of the partition.

        Per slice: A is ``(max_loras, 1, rank_per_partition, input_size)``,
        B is ``(max_loras, 1, output_size // 2, max_lora_rank)``, and the
        bias (if enabled) is ``(max_loras, 1, output_size // 2)``. Also sets
        ``tp_size``, ``tp_rank``, ``output_dim`` and ``output_slices``.

        Raises:
            ValueError: unless the base layer has exactly two equal-sized
                output slices.
        """
        self.lora_config = lora_config

        if not (len(self.base_layer.output_sizes) == self.n_slices == 2
                and self.base_layer.output_sizes[0]
                == self.base_layer.output_sizes[1]):
            raise ValueError(
                "LoRAColumnParallelLinear2Slice requires 2 slices with "
                "the same size.")
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        # With fully sharded LoRAs, lora_a's rank dim is split across ranks.
        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(self.n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                self.output_size // 2,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(self.n_slices))
        if lora_config.bias_enabled:
            self.lora_bias_stacked = tuple(
                torch.zeros(
                    max_loras,
                    1,
                    self.output_size // 2,
                    dtype=lora_config.lora_dtype,
                    device=self.device,
                ) for _ in range(self.n_slices))
        self.output_dim = self.lora_b_stacked[0].shape[2]
        self.output_slices = (self.output_dim, self.output_dim)

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """lora_a is replicated across ranks, so it is returned unchanged."""
        return lora_a

    def slice_lora_b(
        self, lora_b: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """Take this rank's ``output_dim`` columns from each sub-LoRA's B."""
        #NOTE: lora_b contains 2 subloras, and each sublora could be None.
        shard_size = self.output_dim
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        return [
            lora_b[0][:, start_idx:end_idx] if lora_b[0] is not None else None,
            lora_b[1][:, start_idx:end_idx] if lora_b[1] is not None else None,
        ]

    def slice_bias(
        self, bias: List[Union[torch.Tensor,
                               None]]) -> List[Union[torch.Tensor, None]]:
        """Take this rank's ``output_dim`` entries from each sub-LoRA bias."""
        # NOTE : each bias could be None.
        shard_size = self.output_dim
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        return [
            bias[0][start_idx:end_idx] if bias[0] is not None else None,
            bias[1][start_idx:end_idx] if bias[1] is not None else None
        ]

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        lora_bias: Optional[torch.Tensor] = None,
    ):
        """Load both sub-LoRAs into slot ``index``.

        ``lora_a``, ``lora_b`` and ``lora_bias`` are two-element sequences
        whose entries may be None. A sub-LoRA's B is written only when its
        A is present. A and B are stored transposed; ``embeddings_tensor``
        is unused.
        """
        self.reset_lora(index)

        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
            if lora_bias is not None:
                lora_bias = self.slice_bias(lora_bias)

        # Two slices (enforced in create_lora_weights), same order as before:
        # slice 0 weights, slice 0 bias, slice 1 weights, slice 1 bias.
        for i in range(2):
            if lora_a[i] is not None:
                self.lora_a_stacked[i][
                    index, 0, :lora_a[i].shape[1], :lora_a[i].shape[0]].copy_(
                        lora_a[i].T, non_blocking=True)
                self.lora_b_stacked[i][
                    index, 0, :lora_b[i].shape[1], :lora_b[i].shape[0]].copy_(
                        lora_b[i].T, non_blocking=True)
            if lora_bias is not None and lora_bias[i] is not None:
                self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
                                              self.lora_bias_stacked)
                # 1-D bias: copied directly (`.T` is deprecated on 1-D).
                self.lora_bias_stacked[i][
                    index, 0, :lora_bias[i].shape[0]].copy_(lora_bias[i],
                                                            non_blocking=True)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """An exact MergedColumnParallelLinear with two packed modules."""
        return (type(source_layer) is MergedColumnParallelLinear
                and len(packed_modules_list) == 2)
688

689
690

class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """
    ColumnParallelLinear layer that is specifically designed for
    qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
    contain only a single LoRA within their qkv_proj layer.

    During inference with Tensor Parallel, the weights of lora_b
    must be accurately partitioned according to the respective ranks.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        self.q_proj_total_size = (self.base_layer.total_num_heads *
                                  self.base_layer.head_size)
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
                                   self.base_layer.head_size)
        # There is only one LoRA layer
        self.n_slices = 1

    def _slice_qkv(self, tensor: torch.Tensor) -> torch.Tensor:
        """Concatenate this rank's q, k and v shards along the last dim.

        The last dim of ``tensor`` is laid out as
        ``[q (q_proj_total_size) | k (kv_proj_total_size) | v (...)]``.
        K/V shards are shared by ``num_kv_head_replicas`` ranks. Also sets
        ``q_shard_id`` and ``kv_shard_id``.
        """
        tp_rank = get_tensor_model_parallel_rank()
        self.q_shard_id = tp_rank
        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
        q_start = self.q_proj_shard_size * self.q_shard_id
        k_start = (self.q_proj_total_size +
                   self.kv_proj_shard_size * self.kv_shard_id)
        v_start = k_start + self.kv_proj_total_size
        q = tensor[..., q_start:q_start + self.q_proj_shard_size]
        k = tensor[..., k_start:k_start + self.kv_proj_shard_size]
        v = tensor[..., v_start:v_start + self.kv_proj_shard_size]
        return torch.cat([q, k, v], dim=-1)

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Return this rank's q/k/v columns of the 2-D ``lora_b``."""
        return self._slice_qkv(lora_b)

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        """Return this rank's q/k/v entries of the 1-D ``bias``.

        Concatenates along the last (only) dim; ``dim=1`` is invalid for a
        1-D tensor.
        """
        return self._slice_qkv(bias)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """An exact QKVParallelLinear carrying a single packed module."""
        return type(source_layer) is QKVParallelLinear and len(
            packed_modules_list) == 1


class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 3 sublayers (slices)
    packed together in qkv proj fashion
    (q_proj + k_proj + v_proj -> qkv_proj).

    This means we have 3 LoRAs, each applied to one slice of the layer.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        # One LoRA slice per packed output (q, k, v); create_lora_weights
        # verifies that there are exactly three.
        self.n_slices = len(self.base_layer.output_sizes)
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate the per-slice LoRA A/B (and optional bias) buffers.

        The main reason for overloading this function is to handle
        inconsistent weight dimensions in qkv lora: the q slice is
        ``num_heads * head_size`` wide while the k and v slices are
        ``num_kv_heads * head_size`` wide.

        Args:
            max_loras: number of LoRA slots to allocate.
            lora_config: LoRA configuration (max rank, dtype, bias flag,
                sharding mode).
            model_config: unused here; accepted for interface compatibility.

        Raises:
            ValueError: if the base layer does not have exactly 3 outputs.
        """
        self.lora_config = lora_config

        if not (len(self.base_layer.output_sizes) == self.n_slices == 3):
            raise ValueError(
                "LoRAColumnParallelLinear3Slice requires 3 slices.")

        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.q_shard_id = self.tp_rank
        # Consecutive groups of `num_kv_head_replicas` ranks map onto the
        # same k/v shard.
        self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas

        # With fully-sharded LoRAs the rank dimension of A is split across
        # tensor-parallel ranks.
        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))

        # Output width of each slice, in q, k, v order.
        self.output_slices = (
            self.q_proj_shard_size,
            self.kv_proj_shard_size,
            self.kv_proj_shard_size,
        )

        # LoRA A buffers share one shape across q, k and v.
        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in self.output_slices)
        # LoRA B buffers follow each slice's output width.
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                slice_width,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for slice_width in self.output_slices)
        if lora_config.bias_enabled:
            self.lora_bias_stacked = tuple(
                torch.zeros(
                    max_loras,
                    1,
                    slice_width,
                    dtype=lora_config.lora_dtype,
                    device=self.device,
                ) for slice_width in self.output_slices)

        self.packed_indices: Optional[torch.Tensor] = None
        self.standard_indices: Optional[torch.Tensor] = None
        # lazily initialized.
        self.indices: torch.Tensor
        self.indices_len: List[int]

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """LoRA A is not sharded for this layer; return it unchanged."""
        return lora_a

    def _qkv_shard_ranges(self) -> List[Tuple[int, int]]:
        """Return this rank's ``(start, end)`` range for the q, k, v slices."""
        q_start = self.q_proj_shard_size * self.q_shard_id
        kv_start = self.kv_proj_shard_size * self.kv_shard_id
        kv_range = (kv_start, kv_start + self.kv_proj_shard_size)
        return [(q_start, q_start + self.q_proj_shard_size), kv_range,
                kv_range]

    def slice_lora_b(
        self, lora_b: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """Cut each of the q/k/v LoRA B weights to this rank's columns.

        ``None`` entries are passed through as ``None``.
        """
        sliced: List[Union[torch.Tensor, None]] = []
        for slice_idx, (start, end) in enumerate(self._qkv_shard_ranges()):
            weight = lora_b[slice_idx]
            sliced.append(None if weight is None else weight[:, start:end])
        return sliced

    def slice_bias(
        self, bias: List[Union[torch.Tensor,
                               None]]) -> List[Union[torch.Tensor, None]]:
        """Cut each of the q/k/v LoRA bias vectors to this rank's shard.

        ``None`` entries are passed through as ``None``.
        """
        sliced: List[Union[torch.Tensor, None]] = []
        for slice_idx, (start, end) in enumerate(self._qkv_shard_ranges()):
            vec = bias[slice_idx]
            sliced.append(None if vec is None else vec[start:end])
        return sliced

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        lora_bias: Optional[torch.Tensor] = None,
    ):
        """Load one adapter's q/k/v LoRA weights into slot ``index``.

        ``lora_a``, ``lora_b`` and ``lora_bias`` are indexed per slice
        (q, k, v); ``None`` entries are skipped. The slot is passed to
        ``reset_lora`` first. When ``tp_size > 1`` the inputs are cut to
        this rank's shard before copying. A and B weights are stored
        transposed; ``embeddings_tensor`` is not used by this layer.
        """
        self.reset_lora(index)

        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
            if lora_bias is not None:
                lora_bias = self.slice_bias(lora_bias)

        for slice_idx in range(self.n_slices):
            weight_b = lora_b[slice_idx]
            if weight_b is not None:
                self.lora_b_stacked[slice_idx][
                    index, 0, :weight_b.shape[1], :weight_b.shape[0]].copy_(
                        weight_b.T, non_blocking=True)

        for slice_idx in range(self.n_slices):
            weight_a = lora_a[slice_idx]
            if weight_a is not None:
                self.lora_a_stacked[slice_idx][
                    index, 0, :weight_a.shape[1], :weight_a.shape[0]].copy_(
                        weight_a.T, non_blocking=True)

        if lora_bias is not None:
            self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
                                          self.lora_bias_stacked)
            for slice_idx in range(self.n_slices):
                bias = lora_bias[slice_idx]
                if bias is not None:
                    # Bias entries are 1-D, so copy them directly: calling
                    # `.T` on a non-2-D tensor is deprecated in PyTorch.
                    self.lora_bias_stacked[slice_idx][
                        index, 0, :bias.shape[0]].copy_(bias,
                                                        non_blocking=True)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Accept an exact QKVParallelLinear built from 3 packed modules."""
        return (type(source_layer) is QKVParallelLinear
                and len(packed_modules_list) == 3)


class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
    """LoRA wrapper for a RowParallelLinear layer.

    The base layer's input features are partitioned across tensor-parallel
    ranks, so LoRA A (which consumes the input) is sliced by rows, while
    LoRA B and the bias are used unsliced.
    """

    def __init__(self, base_layer: RowParallelLinear) -> None:
        super().__init__(base_layer)

        self.tp_size = get_tensor_model_parallel_world_size()
        # The LoRA input width is this rank's partition of the base input.
        self.input_size = self.base_layer.input_size_per_partition
        self.output_size = self.base_layer.output_size
        self.tp_rank = get_tensor_model_parallel_rank()
        # A single LoRA slice covers the whole layer.
        self.n_slices = 1

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """Keep only the rows of ``lora_a`` that match this rank's inputs."""
        rows_per_rank = self.input_size
        first_row = self.tp_rank * rows_per_rank
        return lora_a[first_row:first_row + rows_per_rank, :]

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """LoRA B is not sharded for row-parallel layers."""
        return lora_b

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        """The LoRA bias is not sharded for row-parallel layers."""
        return bias

    def forward(self, input_):
        """Forward of RowParallelLinear

        Args:
            input_: tensor whose last dimension is `input_size`. If
                    `input_is_parallel` is set, then the last dimension
                    is `input_size // tp_size`.

        Returns:
            - output
            - bias
        """
        base = self.base_layer

        # Set up backprop all-reduce.
        if base.input_is_parallel:
            input_parallel = input_
        else:
            # TODO: simplify code below
            input_parallel = split_tensor_along_last_dim(
                input_,
                num_partitions=base.tp_size)[self.tp_rank].contiguous()

        # Matrix multiply.
        output_parallel = self.apply(input_parallel)

        if base.reduce_results and base.tp_size > 1:
            reduced = tensor_model_parallel_all_reduce(output_parallel)
        else:
            reduced = output_parallel

        # With skip_bias_add the bias is handed back to the caller instead
        # of being added here.
        if base.skip_bias_add:
            return reduced, base.bias
        if base.bias is None:
            return reduced, None
        return reduced + base.bias, None

    @property
    def weight(self):
        """Base weight, or ``qweight`` when the base layer has no ``weight``."""
        base = self.base_layer
        if hasattr(base, "weight"):
            return base.weight
        return base.qweight

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Accept an exact RowParallelLinear (subclasses are rejected)."""
        return type(source_layer) is RowParallelLinear


class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
    """
    LoRA wrapper for LogitsProcessor, with extra logic to handle the
    application of the LoRA adapter and added LoRA vocabulary.

    Args:
        base_layer: LogitsProcessor layer
        hidden_size: hidden size of the model
        dtype: data type of the model
        device: device of the model
        sharded_to_full_mapping: index mapping from sharded vocab to full vocab
            received from base_layer.get_sharded_to_full_mapping(). If None,
            no reindexing will be done.
    """

    def __init__(self, base_layer: LogitsProcessor, hidden_size: int,
                 dtype: torch.dtype, device: torch.device,
                 sharded_to_full_mapping: Optional[List[int]]) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.sharded_to_full_mapping = sharded_to_full_mapping

    @property
    def logits_as_input(self):
        return self.base_layer.logits_as_input

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def scale(self):
        return self.base_layer.scale

    @property
    def soft_cap(self):
        return self.base_layer.soft_cap

    @property
    def use_gather(self):
        return self.base_layer.use_gather

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    @property
    def should_modify_greedy_probs_inplace(self):
        return self.base_layer.should_modify_greedy_probs_inplace

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate LoRA A/B buffers and the extra-vocab embedding table.

        Args:
            max_loras: number of LoRA slots to allocate.
            lora_config: LoRA configuration (max rank, dtype, vocab padding,
                extra vocab size).
            model_config: unused here; accepted for interface compatibility.

        Raises:
            ValueError: if the base vocab size exceeds 257024.
        """
        # TODO: Verify if this condition can be further relaxed
        # Only an upper bound is enforced. The previous chained comparison
        # `32000 < vocab_size > 257024` already reduced to exactly this
        # check, so no lower bound was ever applied.
        max_vocab_size = 257024
        vocab_size = self.base_layer.vocab_size
        if vocab_size > max_vocab_size:
            raise ValueError("When using LoRA, vocab size must be "
                             f"<= {max_vocab_size}, got {vocab_size}")
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                # Pad for kernel compatibility
                math.ceil(self.base_layer.vocab_size /
                          lora_config.lora_vocab_padding_size) *
                lora_config.lora_vocab_padding_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        # Rows for unused extra-vocab slots stay at -inf.
        self.embeddings_tensors = torch.full(
            (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
            fill_value=float("-inf"),
            dtype=self.dtype,
            device=self.device,
        )
        if self.sharded_to_full_mapping is not None:
            self.sharded_to_full_mapping_gpu = torch.tensor(
                self.sharded_to_full_mapping,
                device=self.device,
                dtype=torch.long)
        else:
            self.sharded_to_full_mapping_gpu = None

    def reset_lora(self, index: int):
        """Zero slot ``index`` and refill its extra-vocab rows with -inf."""
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = float("-inf")

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Load one adapter into slot ``index``.

        A and B are stored transposed. ``embeddings_tensor`` (if given) fills
        the leading rows/columns of the slot's extra-vocab table. ``bias`` is
        not used by this layer.
        """
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1], ] = embeddings_tensor

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: VocabParallelEmbedding,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        """Compute logits with LoRA deltas and the added LoRA vocabulary.

        Returns ``None`` when ``tensor_model_parallel_gather`` returns
        ``None`` (presumably on non-destination ranks — confirm in
        ``vllm.distributed``).
        """
        # Get the logits for the next tokens.
        logits = lm_head.linear_method.apply(lm_head, hidden_states)
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        if logits is None:
            return None

        if self.sharded_to_full_mapping_gpu is not None:
            # Reindex full logits tensor to ensure 1:1 mapping between
            # index and token_id
            # Example for:
            #   org_vocab_size = 4
            #   added_vocab_size = 2
            #   pad_to_size = 8
            #   tp_size = 2

            # indices:  [0, 1, 2,  3, 4, 5, 6,  7]
            # token_id: [0, 1, 4, -1, 2, 3, 5, -1]

            # Therefore, the mapping is expected to be:
            # [0, 1, 4, 5, 2, 6, 3, 7] so that when we reindex
            # (new[i] = old[mapping[i]]), we get:
            # indices:  [0, 1, 2, 3, 4, 5,  6,  7]
            # token_id: [0, 1, 2, 3, 4, 5, -1, -1]
            logits = logits[:, self.sharded_to_full_mapping_gpu]

        lora_logits = torch.empty(
            self.embeddings_tensors.shape[0] + 1,
            self.embeddings_tensors.shape[1],
            hidden_states.shape[0],
            dtype=self.embeddings_tensors.dtype,
            device=self.embeddings_tensors.device,
        )
        torch.matmul(self.embeddings_tensors,
                     hidden_states.T,
                     out=lora_logits[:-1])
        # The extra trailing row is filled with -inf.
        lora_logits[-1] = float("-inf")
        lora_logits = lora_logits.mT
        indices_padded = self.punica_wrapper.sampler_indices_padded
        lora_logits = (lora_logits.reshape(
            lora_logits.shape[0] * lora_logits.shape[1],
            lora_logits.shape[2],
        ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"),
                                                      posinf=float("inf"),
                                                      neginf=float("-inf")))
        logits[:,
               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
               lora_logits.shape[1]] = lora_logits

        # LogitsProcessorWithLoRA always using bgmv
        self.punica_wrapper.add_lora_logits(logits, hidden_states,
                                            self.lora_a_stacked,
                                            self.lora_b_stacked, 1.0)

        # Remove paddings in vocab (if any).
        logits = logits[:, :self.base_layer.vocab_size]
        return logits

    def forward(self, *args, **kwargs):
        """Run the base layer's ``forward`` with ``self`` as the instance."""
        return type(self.base_layer).forward(self, *args, **kwargs)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # Special handling for the LogitsProcessor.
        return False


class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
    """Linear-scaling RoPE shared by several LoRA adapters.

    On ``create_lora_weights`` the wrapped rotary embedding is replaced by a
    ``LinearScalingRotaryEmbedding`` that holds every scaling factor needed
    (the base layer's own factor plus the configured long-LoRA factors).
    ``forward`` then selects each token's factor via per-token offsets taken
    from the punica wrapper, so multiple adapters can share one kernel.
    """

    def __init__(self, base_layer: RotaryEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer

    @property
    def scaling_factors(self):
        return self.base_layer.scaling_factors

    @property
    def rotary_dim(self):
        return self.base_layer.rotary_dim

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Rebuild the base layer with the union of all scaling factors.

        The factor set is the base layer's own factor (1.0 when it is not a
        ``LinearScalingRotaryEmbedding``) merged with
        ``lora_config.long_lora_scaling_factors``, de-duplicated and sorted.
        ``max_loras`` and ``model_config`` are not used.
        """
        configured = lora_config.long_lora_scaling_factors
        long_lora_factors = list(configured) if configured else []

        current = self.base_layer
        if isinstance(current, LinearScalingRotaryEmbedding):
            own_factor = current.scaling_factor
        else:
            own_factor = 1.0

        all_factors = sorted({own_factor, *long_lora_factors})
        self.base_layer = LinearScalingRotaryEmbedding(
            current.head_size,
            current.rotary_dim,
            current.max_position_embeddings,
            current.base,
            current.is_neox_style,
            all_factors,
            current.dtype,
        )

    def reset_lora(self, index: int):
        """No per-slot state to clear."""

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """No per-slot weights to load."""

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embedding using the punica long-LoRA offsets."""
        offsets = self.punica_wrapper.long_lora_indices
        return self.base_layer(positions, query, key, offsets=offsets)

    @property
    def scaling_factor_to_offset(self) -> Dict[float, int]:
        return self.base_layer.scaling_factor_to_offset

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        return type(source_layer) in (LinearScalingRotaryEmbedding,
                                      RotaryEmbedding)

    def extra_repr(self) -> str:
        return self.base_layer.extra_repr()