layers.py 44.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
# pylint: disable=unused-argument
import math
from dataclasses import dataclass
6
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
7
8
9
10
11
12

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

13
from vllm.adapter_commons.layers import AdapterMapping
14
from vllm.config import LoRAConfig
15
16
17
18
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
19
                              tensor_model_parallel_all_reduce)
20
from vllm.distributed.utils import divide
21
# yapf: disable
22
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
23
                                               LinearBase,
24
                                               MergedColumnParallelLinear,
25
                                               QKVParallelLinear,
26
                                               ReplicatedLinear,
27
                                               RowParallelLinear)
28
# yapf: enable
29
from vllm.model_executor.layers.logits_processor import LogitsProcessor
30
31
from vllm.model_executor.layers.rotary_embedding import (
    LinearScalingRotaryEmbedding, RotaryEmbedding)
32
from vllm.model_executor.layers.vocab_parallel_embedding import (
33
    VocabParallelEmbedding)
34
from vllm.platforms import current_platform
35
36

if TYPE_CHECKING:
37
    from vllm.lora.punica_wrapper import PunicaWrapperBase
38
39


40
41
42
def _get_lora_device(base_layer: nn.Module) -> torch.device:
    # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
    """Returns the device for where to place the LoRA tensors."""
Jee Li's avatar
Jee Li committed
43
    # unquantizedLinear
44
45
    if hasattr(base_layer, "weight"):
        return base_layer.weight.device
46
47
48
    # Compressed Tensor
    elif hasattr(base_layer, "weight_packed"):
        return base_layer.weight_packed.device
49
    # GPTQ/AWQ
Jee Li's avatar
Jee Li committed
50
51
52
53
54
    elif hasattr(base_layer, "qweight"):
        return base_layer.qweight.device
    # marlin
    elif hasattr(base_layer, "B"):
        return base_layer.B.device
55
56
57
    # HQQ marlin
    elif hasattr(base_layer, "W_q"):
        return base_layer.W_q.device
Jee Li's avatar
Jee Li committed
58
59
    else:
        raise ValueError(f"Unsupported base layer: {base_layer}")
60
61


62
63
64
65
66
67
68
def _not_fully_sharded_can_replace(can_replace):
    """
    decorator which adds the condition of not using fully sharded loras
    intended to wrap can_replace_layer()
    """

    def dec(*args, **kwargs):
69
70
        decorate = kwargs.pop("decorate") if "decorate" in kwargs else True
        condition = (not kwargs["lora_config"].fully_sharded_loras
71
72
73
74
75
76
                     if decorate else True)
        return can_replace(*args, **kwargs) and condition

    return dec


77
@dataclass
class LoRAMapping(AdapterMapping):
    """AdapterMapping with an additional ``is_prefill`` flag."""
    # Defaults to False when not supplied by the caller.
    is_prefill: bool = False
80
81
82
83


class BaseLayerWithLoRA(nn.Module):
    """Common interface of the LoRA-enabled layer wrappers in this file.

    The slicing, weight-creation, reset and set hooks are stubs here (their
    bodies are ``...``) and are meant to be overridden. ``can_replace_layer``
    raises ``NotImplementedError`` unless a subclass overrides it.
    """

    def slice_lora_a(
        self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
        """Slice lora a if splitting for tensor parallelism."""
        ...

    def slice_lora_b(
        self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
        """Slice lora b if splitting with tensor parallelism."""
        ...

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Initializes lora matrices."""
        ...

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Overwrites lora tensors at index."""
        ...

    def set_mapping(
        self,
        punica_wrapper,
    ):
        """Stores ``punica_wrapper`` on the layer as ``self.punica_wrapper``."""
        self.punica_wrapper: PunicaWrapperBase = punica_wrapper

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        raise NotImplementedError

137
138
139
140
141
142

class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper around a ``VocabParallelEmbedding`` base layer.

    Holds stacked LoRA A/B tensors plus a per-LoRA tensor of extra-vocabulary
    embeddings, and applies them on top of the base layer's output through
    the punica wrapper installed with ``set_mapping``.
    """

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        # Both are assigned in create_lora_weights().
        self.embeddings_slice: Optional[Tuple[int, int]]
        self.embeddings_weights: Optional[torch.Tensor]

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        """Allocates the zero-initialised LoRA and extra-embedding tensors.

        If this partition owns added-vocabulary rows, a view onto those rows
        of the base weight is kept in ``embeddings_weights`` and every base
        weight row from ``num_org_embeddings_per_partition`` onwards is
        zeroed.
        """

        if self.base_layer.num_added_embeddings_per_partition > 0:
            # We can start adding lora weights
            self.embeddings_weights = self.base_layer.weight.data[
                self.base_layer.num_org_embeddings_per_partition:self.
                base_layer.num_org_embeddings_per_partition +
                self.base_layer.num_added_embeddings_per_partition]
            # Range of this shard's added tokens, relative to the start of
            # the added vocabulary.
            self.embeddings_slice = (
                self.base_layer.shard_indices.added_vocab_start_index -
                self.base_layer.org_vocab_size,
                self.base_layer.shard_indices.added_vocab_end_index -
                self.base_layer.org_vocab_size)
            self.base_layer.weight.data[
                self.base_layer.num_org_embeddings_per_partition:].fill_(0)
        else:
            self.embeddings_slice = None
            self.embeddings_weights = None

        self.embeddings_tensors = torch.zeros(
            (
                max_loras,
                lora_config.lora_extra_vocab_size,
                self.base_layer.embedding_dim,
            ),
            dtype=self.base_layer.weight.dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                self.base_layer.org_vocab_size +
                lora_config.lora_extra_vocab_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.embedding_dim,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        # 2-D view of lora_a_stacked (first two dims merged); forward() uses
        # it as the embedding table for F.embedding.
        self.lora_a_stacked_2d = self.lora_a_stacked.view(
            self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
            self.lora_a_stacked.shape[2],
        )

    def reset_lora(self, index: int):
        """Zeroes LoRA A, LoRA B and the extra embeddings at ``index``."""
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Overwrites the tensors stored at ``index``.

        ``lora_a`` is copied as-is and ``lora_b`` transposed, each into the
        leading corner of its slot. ``bias`` is accepted but not used here.
        """
        self.reset_lora(index)
        self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
            lora_a, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index,
                :embeddings_tensor.shape[0],
                :embeddings_tensor.shape[1],
            ].copy_(embeddings_tensor, non_blocking=True)
            if self.embeddings_slice is not None:
                # TODO(yard1): Optimize this copy, we don't need to copy
                # everything, just the modified part
                embeddings = self.embeddings_tensors.view(
                    self.embeddings_tensors.shape[0] *
                    self.embeddings_tensors.shape[1],
                    self.embeddings_tensors.shape[2],
                )[self.embeddings_slice[0]:self.embeddings_slice[1]]
                assert self.embeddings_weights is not None
                self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embeds ``x`` with the base layer and adds the LoRA contribution.

        NOTE: ``x`` is modified in place (``x.add_``) before being passed to
        the base layer.
        """
        added_tokens_mask = x > self.base_layer.org_vocab_size - 1
        embeddings_indices = self.punica_wrapper.embeddings_indices
        indices = embeddings_indices[1].view_as(x)
        # Computed before x is shifted in place below.
        full_lora_a_embeddings = F.embedding(
            x + indices,
            self.lora_a_stacked_2d,
        )
        indices = embeddings_indices[0].view_as(x)
        # Only tokens beyond the original vocabulary are offset.
        full_output = self.base_layer.forward(
            x.add_(indices * added_tokens_mask))

        full_output_org = full_output
        if full_output.ndim == 3:
            full_output = full_output.view(
                full_output.shape[0] * full_output.shape[1], -1)
        if full_lora_a_embeddings.ndim == 3:
            full_lora_a_embeddings = full_lora_a_embeddings.view(
                full_lora_a_embeddings.shape[0] *
                full_lora_a_embeddings.shape[1],
                -1,
            )
        self.punica_wrapper.add_lora_embedding(full_output,
                                               full_lora_a_embeddings,
                                               self.lora_b_stacked,
                                               add_input=True)
        return full_output.view_as(full_output_org)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """True only when ``source_layer`` is exactly VocabParallelEmbedding."""
        return type(source_layer) is VocabParallelEmbedding

277

278
class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
    """Shared implementation for LoRA wrappers around linear base layers.

    ``output_slices``, ``tp_size``, ``output_size`` and ``n_slices`` are only
    declared here; subclasses assign them in their constructors.
    """

    def __init__(self, base_layer: LinearBase):
        super().__init__()
        self.base_layer = base_layer
        self.input_size = self.base_layer.input_size
        self.device = _get_lora_device(self.base_layer)
        # Stays None unless lora_config.bias_enabled is set when the weights
        # are created.
        self.lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]] = None

        self.output_slices: Tuple[int, ...]
        self.tp_size: int
        self.output_size: int
        self.n_slices: int

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocates zeroed LoRA A/B (and optional bias) stacks.

        One tensor per slice is created. A is
        ``(max_loras, 1, lora_a_out_size, input_size)``, B is
        ``(max_loras, 1, lora_b_out_size, max_lora_rank)``; the sizes depend
        on the base layer type and on ``fully_sharded_loras``.

        Raises:
            NotImplementedError: for an unsupported base layer type.
        """
        self.lora_config = lora_config
        if isinstance(self.base_layer, ReplicatedLinear):
            lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = self.output_size

        elif isinstance(self.base_layer, ColumnParallelLinear):
            lora_a_out_size = (lora_config.max_lora_rank if
                               not lora_config.fully_sharded_loras else divide(
                                   lora_config.max_lora_rank, self.tp_size))
            lora_b_out_size = self.output_size

        elif isinstance(self.base_layer, RowParallelLinear):
            lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = (self.output_size if
                               not lora_config.fully_sharded_loras else divide(
                                   self.output_size, self.tp_size))
        else:
            raise NotImplementedError

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_out_size,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(self.n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_b_out_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(self.n_slices))
        if lora_config.bias_enabled:
            lora_bias_out_size = lora_b_out_size
            self.lora_bias_stacked = tuple(
                torch.zeros(
                    max_loras,
                    1,
                    lora_bias_out_size,
                    dtype=lora_config.lora_dtype,
                    device=self.device,
                ) for _ in range(self.n_slices))
        self.output_slices = (self.lora_b_stacked[0].shape[2], )

    def reset_lora(self, index: int):
        """Zeroes the A, B and (if enabled) bias entries at ``index``."""
        for s_index in range(self.n_slices):
            self.lora_a_stacked[s_index][index] = 0
            self.lora_b_stacked[s_index][index] = 0
            if self.lora_config.bias_enabled:
                # Make mypy happy
                self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
                                              self.lora_bias_stacked)
                self.lora_bias_stacked[s_index][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        lora_bias: Optional[torch.Tensor] = None,
    ):
        """Overwrites the single-slice LoRA tensors at ``index``.

        When ``tp_size > 1`` the inputs are first passed through
        ``slice_lora_a`` / ``slice_lora_b`` (and ``slice_bias``, which the
        concrete subclass must provide). ``embeddings_tensor`` is unused.
        """
        # Except for QKVParallelLinearWithLoRA and
        # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
        # store weights in a tuple of size 1. These two layers will
        # override this function.
        assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) ==
                self.n_slices == 1)

        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
            if lora_bias is not None:
                lora_bias = self.slice_bias(lora_bias)

        # Both matrices are stored transposed relative to how they arrive.
        self.lora_a_stacked[0][index,
                               0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                   lora_a.T, non_blocking=True)
        self.lora_b_stacked[0][index,
                               0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                   lora_b.T, non_blocking=True)
        if lora_bias is not None:

            self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
                                          self.lora_bias_stacked)
            assert len(self.lora_bias_stacked)
            self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_(
                lora_bias.T, non_blocking=True)

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Runs the base layer's quant method, then adds the LoRA output.

        The LoRA contribution is accumulated into ``output`` through
        ``punica_wrapper.add_lora_linear``.
        """
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        # In transformers backend, x and output have extra batch dimension like
        # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
        # therefore we need to flatten the batch dimensions.
        if x.ndim == 3 and output.ndim == 3:
            output = output.flatten(0, 1)
            x = x.flatten(0, 1)

        self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked,
                                            self.lora_b_stacked,
                                            self.lora_bias_stacked, 1.0,
                                            self.output_slices)
        return output

412
413
414
415
416
417
418
419
420
421

class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
    """LoRA wrapper around a ``ReplicatedLinear`` base layer."""

    def __init__(self, base_layer: ReplicatedLinear) -> None:
        super().__init__(base_layer)
        # To ensure interface compatibility, set to 1 always.
        self.tp_size = 1
        self.output_size = self.base_layer.output_size
        self.n_slices = 1

    def forward(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Forward of ReplicatedLinearWithLoRA

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias

            Only ``output`` is returned when the base layer's
            ``return_bias`` is false.
        """
        # The bias is either applied in the matmul or handed back to the
        # caller, depending on skip_bias_add.
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output = self.apply(input_, bias)

        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)

        if not self.base_layer.return_bias:
            return output

        return output, output_bias

    # ReplicatedLinear should always be replaced, regardless of the fully
    # sharded LoRAs setting, because it is, by definition, copied per GPU.
    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """True only when ``source_layer`` is exactly a ReplicatedLinear."""
        return type(source_layer) is ReplicatedLinear
459
460


461
class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
    """
    LoRA on top of ColumnParallelLinear layer.
    LoRA B is sliced for tensor parallelism.
    There are two types for the `base_layer`:
    1. ColumnParallelLinear, e.g.`dense_h_to_4h` in `FalconForCausalLM`.
    2. MergedColumnParallelLinear, e.g.`gate_up_proj` in `Phi3ForCausalLM`.
    """

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__(base_layer)
        # The base_layer type is ColumnParallelLinear or
        # MergedColumnParallelLinear, their weight sharding logic is
        # inconsistent when TP is greater than 1.
        self.is_merged_col_linear = type(
            base_layer) is MergedColumnParallelLinear
        self.tp_size = get_tensor_model_parallel_world_size()
        self.output_size = self.base_layer.output_size_per_partition
        # There is only one LoRA layer
        self.n_slices = 1

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """Returns ``lora_a`` unchanged (LoRA A is not sliced here)."""
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Returns the columns of ``lora_b`` owned by this TP rank."""
        # Applicable to cases where the base_layer is
        # MergedColumnParallelLinear.
        if self.is_merged_col_linear:
            tp_rank = get_tensor_model_parallel_rank()
            # Each rank takes its shard from both halves of the merged
            # output and concatenates them.
            shard_size = self.output_size // 2
            offset = lora_b.shape[-1] // 2

            left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) *
                                 shard_size]
            right_weight = lora_b[:, offset + tp_rank * shard_size:offset +
                                  (tp_rank + 1) * shard_size]
            lora_b = torch.cat([left_weight, right_weight], dim=1)
        # Applicable to cases where the base_layer is
        # ColumnParallelLinear.
        else:
            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
            shard_size = self.output_size
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            lora_b = lora_b[:, start_idx:end_idx]
        return lora_b

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        """Returns the part of ``bias`` owned by this TP rank.

        ``None`` is passed through unchanged. For a merged base layer the
        bias is sliced per half, mirroring ``slice_lora_b``, so that it lines
        up with the LoRA B columns selected for the same rank.
        """
        if bias is None:
            return bias
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        if self.is_merged_col_linear:
            # A single contiguous chunk would take the wrong entries when
            # TP > 1; pick this rank's shard from each half instead.
            shard_size = self.output_size // 2
            offset = bias.shape[-1] // 2
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            left_bias = bias[start_idx:end_idx]
            right_bias = bias[offset + start_idx:offset + end_idx]
            return torch.cat([left_bias, right_bias], dim=0)
        shard_size = self.output_size
        start_idx = tensor_model_parallel_rank * shard_size
        end_idx = (tensor_model_parallel_rank + 1) * shard_size
        bias = bias[start_idx:end_idx]
        return bias

    def forward(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Forward of ColumnParallelLinear

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias

            Only ``output`` is returned when the base layer's
            ``return_bias`` is false.
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output_parallel = self.apply(input_, bias)
        if self.base_layer.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel

        if not self.base_layer.return_bias:
            return output

        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """True for a plain ColumnParallelLinear, or for a
        MergedColumnParallelLinear that packs exactly one module."""
        return type(source_layer) is ColumnParallelLinear or (
            type(source_layer) is MergedColumnParallelLinear
            and len(packed_modules_list) == 1)

562
563
564
565
566
567
568
569
570
571

class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
    packed together (eg. gate_proj + up_proj -> gate_up_proj).

    This means we have 2 LoRAs, each applied to one half of the layer.

    Both slices must have the same size.
    """

    def __init__(
        self, base_layer: Union[MergedColumnParallelLinear,
                                QKVParallelLinear]) -> None:
        super().__init__(base_layer)
        # There are two LoRA layers
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        # the output_sizes in MergedColumnParallelLinear is not sharded by tp
        # we need to divide it by the tp_size to get correct slices size
        output_sizes = self.base_layer.output_sizes
        self.output_slices = tuple(
            divide(output_size, self.tp_size) for output_size in output_sizes)
        self.n_slices = len(self.output_slices)
        # Every slice is sharded by the same rank.
        self.output_ids = (self.tp_rank, ) * self.n_slices

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """
        The main reason for overriding this function is to enhance code
        maintainability.

        Allocates one zeroed A and one zeroed B tensor per slice; each B (and
        optional bias) tensor is sized by its entry in ``output_slices``.
        """
        self.lora_config = lora_config

        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(self.n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                output_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for output_size in self.output_slices)
        if lora_config.bias_enabled:
            self.lora_bias_stacked = tuple(
                torch.zeros(
                    max_loras,
                    1,
                    output_size,
                    dtype=lora_config.lora_dtype,
                    device=self.device,
                ) for output_size in self.output_slices)

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """Returns ``lora_a`` unchanged."""
        return lora_a

    def slice_lora_b(
        self, lora_b: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """Replaces each non-None entry with this rank's column shard.

        The list is modified in place and also returned.
        """
        for i, (shard_id, shard_size) in enumerate(
                zip(self.output_ids, self.output_slices)):
            if (lora_b_i := lora_b[i]) is not None:
                lora_b[i] = lora_b_i[:, shard_size * shard_id:shard_size *
                                     (shard_id + 1)]
        return lora_b

    def slice_bias(
        self, bias: List[Union[torch.Tensor,
                               None]]) -> List[Union[torch.Tensor, None]]:
        """Replaces each non-None bias with this rank's shard.

        The list is modified in place and also returned.
        """
        for i, (shard_id, shard_size) in enumerate(
                zip(self.output_ids, self.output_slices)):
            if (bias_i := bias[i]) is not None:
                bias[i] = bias_i[shard_size * shard_id:shard_size *
                                 (shard_id + 1)]
        return bias

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        lora_bias: Optional[torch.Tensor] = None,
    ):
        """Overwrites the per-slice LoRA tensors at ``index``.

        ``lora_a``, ``lora_b`` and ``lora_bias`` are indexed per slice here
        (``lora_a[i]`` etc.); slices whose entry is ``None`` keep the zeros
        written by ``reset_lora``. ``embeddings_tensor`` is unused.
        """
        self.reset_lora(index)

        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
            if lora_bias is not None:
                lora_bias = self.slice_bias(lora_bias)

        for i in range(self.n_slices):
            if (lora_a_i := lora_a[i]) is not None:
                self.lora_a_stacked[i][
                    index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_(
                        lora_a_i.T, non_blocking=True)
            if (lora_b_i := lora_b[i]) is not None:
                self.lora_b_stacked[i][
                    index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_(
                        lora_b_i.T, non_blocking=True)

        if lora_bias is not None:
            self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
                                          self.lora_bias_stacked)
            for i in range(self.n_slices):
                if (lora_bias_i := lora_bias[i]) is not None:
                    self.lora_bias_stacked[i][index,
                                              0, :lora_bias_i.shape[0]].copy_(
                                                  lora_bias_i.T,
                                                  non_blocking=True)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """True for a MergedColumnParallelLinear packing exactly two modules."""
        return (type(source_layer) is MergedColumnParallelLinear
                and len(packed_modules_list) == 2)
703

704

705
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """
    ColumnParallelLinear layer that is specifically designed for
    qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
    only contains a single LoRA within their qkv_proj layer.

    During inference with Tensor Parallel, the weights of lora_b
    must be accurately partitioned according to the respective ranks.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        self.q_proj_total_size = (self.base_layer.total_num_heads *
                                  self.base_layer.head_size)
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
                                   self.base_layer.head_size)
        # There is only one LoRA layer
        self.n_slices = 1

    def _update_shard_ids(self) -> None:
        """Derives the Q and KV shard ids from the current TP rank."""
        tp_rank = get_tensor_model_parallel_rank()
        self.q_shard_id = tp_rank
        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Returns this rank's Q, K and V columns of ``lora_b``, concatenated
        along dim 1 in that order."""
        self._update_shard_ids()
        lora_b_q = lora_b[:, self.q_proj_shard_size *
                          self.q_shard_id:self.q_proj_shard_size *
                          (self.q_shard_id + 1)]
        k_offset = self.q_proj_total_size
        lora_b_k = lora_b[:, k_offset +
                          self.kv_proj_shard_size * self.kv_shard_id:k_offset +
                          self.kv_proj_shard_size * (self.kv_shard_id + 1)]
        v_offset = k_offset + self.kv_proj_total_size
        lora_b_v = lora_b[:, v_offset +
                          self.kv_proj_shard_size * self.kv_shard_id:v_offset +
                          self.kv_proj_shard_size * (self.kv_shard_id + 1)]
        lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
        return lora_b

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        """Returns this rank's Q, K and V entries of ``bias``, concatenated.

        The shard ids are derived here rather than relying on a previous
        ``slice_lora_b`` call. The slices are taken along dim 0, so they are
        concatenated along dim 0 as well.
        """
        self._update_shard_ids()
        bias_q = bias[self.q_proj_shard_size *
                      self.q_shard_id:self.q_proj_shard_size *
                      (self.q_shard_id + 1)]
        k_offset = self.q_proj_total_size
        bias_k = bias[k_offset +
                      self.kv_proj_shard_size * self.kv_shard_id:k_offset +
                      self.kv_proj_shard_size * (self.kv_shard_id + 1)]
        v_offset = k_offset + self.kv_proj_total_size
        bias_v = bias[v_offset +
                      self.kv_proj_shard_size * self.kv_shard_id:v_offset +
                      self.kv_proj_shard_size * (self.kv_shard_id + 1)]
        # dim=1 does not exist on the 1-D slices produced above.
        bias = torch.cat([bias_q, bias_k, bias_v], dim=0)
        return bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """True for a QKVParallelLinear that packs exactly one module."""
        return type(source_layer) is QKVParallelLinear and len(
            packed_modules_list) == 1


773
class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
    """LoRA wrapper for a QKVParallelLinear whose q, k and v projections are
    packed into one layer (q_proj + k_proj + v_proj -> qkv_proj).

    Three LoRAs are held, one applied to each slice of the layer.

    The Q slice may have a different shape than the K and V slices, which
    both have the same shape.
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        layer = self.base_layer

        # There are three LoRA layers, one per packed output size.
        self.n_slices = len(layer.output_sizes)
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        # Shard widths: number of heads times the width of one head.
        head_size = layer.head_size
        q_shard_size = layer.num_heads * head_size
        kv_shard_size = layer.num_kv_heads * head_size
        self.q_proj_shard_size = q_shard_size
        self.kv_proj_shard_size = kv_shard_size

        # Consecutive groups of `num_kv_head_replicas` ranks share one
        # KV shard id; the Q shard id is simply the rank.
        q_id = self.tp_rank
        kv_id = self.tp_rank // layer.num_kv_head_replicas
        self.q_shard_id = q_id
        self.kv_shard_id = kv_id

        self.output_slices = (q_shard_size, kv_shard_size, kv_shard_size)
        self.output_ids = (q_id, kv_id, kv_id)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """
        Delegate weight creation to the parent class.

        The main reason for overloading this function is to handle
        inconsistent weight dimensions in qkv lora.
        """
        super().create_lora_weights(max_loras, lora_config, model_config)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Return True for an exact QKVParallelLinear with three packed
        modules."""
        if type(source_layer) is not QKVParallelLinear:
            return False
        return len(packed_modules_list) == 3
832

833

834
class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
    """LoRA wrapper around a RowParallelLinear base layer.

    LoRA A is sliced along its first dimension per tensor-parallel rank;
    LoRA B and the bias are used unsliced.
    """

    def __init__(self, base_layer: RowParallelLinear) -> None:
        super().__init__(base_layer)
        layer = self.base_layer

        self.tp_size = get_tensor_model_parallel_world_size()
        # reset input_size to the per-partition width
        self.input_size = layer.input_size_per_partition
        self.output_size = layer.output_size

        self.tp_rank = get_tensor_model_parallel_rank()
        # There is only one LoRA layer.
        self.n_slices = 1

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """Return the rows of ``lora_a`` that belong to this rank."""
        width = self.input_size
        first_row = self.tp_rank * width
        return lora_a[first_row:first_row + width, :]

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Return ``lora_b`` unchanged."""
        return lora_b

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        """Return ``bias`` unchanged."""
        return bias

    def forward(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Forward of RowParallelLinear

        Args:
            input_: tensor whose last dimension is `input_size`. If
                    `input_is_parallel` is set, then the last dimension
                    is `input_size // tp_size`.

        Returns:
            - output
            - bias

            Only ``output`` is returned when the base layer's
            ``return_bias`` is falsy.
        """
        layer = self.base_layer

        # Take this rank's slice of the input unless it is already split.
        if layer.input_is_parallel:
            input_parallel = input_
        else:
            # TODO: simplify code below
            shards = split_tensor_along_last_dim(
                input_, num_partitions=layer.tp_size)
            input_parallel = shards[self.tp_rank].contiguous()

        # Matrix multiply.
        output_parallel = self.apply(input_parallel)

        # Reduce across ranks only when requested and when there is more
        # than one rank.
        if layer.reduce_results and layer.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if layer.skip_bias_add:
            # Hand the bias back to the caller instead of adding it here.
            output = output_
            output_bias = layer.bias
        else:
            output_bias = None
            if layer.bias is not None:
                output = output_ + layer.bias
            else:
                output = output_

        if layer.return_bias:
            return output, output_bias
        return output

    @property
    def weight(self):
        """The base layer's ``weight`` if it has one, else its ``qweight``."""
        layer = self.base_layer
        if hasattr(layer, "weight"):
            return layer.weight
        return layer.qweight

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Return True only when ``source_layer`` is exactly a
        RowParallelLinear (subclasses do not match)."""
        layer_type = type(source_layer)
        return layer_type is RowParallelLinear
920

921

922
class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
    """
    LoRA wrapper for LogitsProcessor, with extra logic to handle the
    application of the LoRA adapter and added LoRA vocabulary.

    Args:
        base_layer: LogitsProcessor layer
        hidden_size: hidden size of the model
        dtype: data type of the model
        device: device of the model
        sharded_to_full_mapping: index mapping from sharded vocab to full vocab
            received from base_layer.get_sharded_to_full_mapping(). If None,
            no reindexing will be done.
    """

    def __init__(self, base_layer: LogitsProcessor, hidden_size: int,
                 dtype: torch.dtype, device: torch.device,
                 sharded_to_full_mapping: Optional[List[int]]) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.sharded_to_full_mapping = sharded_to_full_mapping

    # The properties below forward the wrapped layer's attributes so this
    # wrapper exposes the same attribute names as the base layer.

    @property
    def logits_as_input(self):
        return self.base_layer.logits_as_input

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def scale(self):
        return self.base_layer.scale

    @property
    def soft_cap(self):
        return self.base_layer.soft_cap

    @property
    def use_all_gather(self):
        return self.base_layer.use_all_gather

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    @property
    def should_modify_greedy_probs_inplace(self):
        return self.base_layer.should_modify_greedy_probs_inplace

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate the stacked LoRA A/B buffers and the extra-vocab
        embedding buffer.

        Args:
            max_loras: number of LoRA slots to allocate.
            lora_config: supplies the max rank, LoRA dtype, vocab padding
                size and extra vocab size.
            model_config: unused here.

        Raises:
            ValueError: if the base layer's vocab size exceeds 257024.
        """
        # TODO: Verify if this condition can be further relaxed
        # Only the upper bound is enforced; smaller vocabularies are
        # accepted. The condition is written explicitly (rather than as a
        # chained comparison) so that the check and the message agree.
        max_vocab_size = 257024
        if self.base_layer.vocab_size > max_vocab_size:
            raise ValueError("When using LoRA, vocab size must be "
                             f"<= {max_vocab_size}")
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                # Pad for kernel compatibility
                math.ceil(self.base_layer.vocab_size /
                          lora_config.lora_vocab_padding_size) *
                lora_config.lora_vocab_padding_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        # Filled with -inf so that unused extra-vocab rows are recognisable
        # as "no embedding"; see the NaN handling in _get_logits.
        self.embeddings_tensors = torch.full(
            (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
            fill_value=float("-inf"),
            dtype=self.dtype,
            device=self.device,
        )
        if self.sharded_to_full_mapping is not None:
            self.sharded_to_full_mapping_gpu = torch.tensor(
                self.sharded_to_full_mapping,
                device=self.device,
                dtype=torch.long)
        else:
            self.sharded_to_full_mapping_gpu = None

    def reset_lora(self, index: int):
        """Clear slot ``index``: zero the A/B weights and reset the extra
        embeddings to -inf."""
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = float("-inf")

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Load one adapter into slot ``index``.

        The slot is reset first. ``lora_a`` and ``lora_b`` are stored
        transposed into the top-left corner of their stacked buffers.
        ``embeddings_tensor``, if given, is copied into the top-left corner
        of the slot's extra-vocab embeddings. ``bias`` is accepted but not
        used.
        """
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index,
                :embeddings_tensor.shape[0],
                :embeddings_tensor.shape[1],
            ] = embeddings_tensor

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: VocabParallelEmbedding,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        """Compute logits including the LoRA contributions.

        Args:
            hidden_states: input activations; indexed as 2-D here
                (``shape[0]`` rows, transposed with ``.T``).
            lm_head: embedding layer whose quant method produces the base
                logits.
            embedding_bias: optional bias added in place to the base logits.

        Returns:
            Logits truncated to the base layer's vocab size, or None when
            the gather step returns None.
        """
        # Get the logits for the next tokens.
        logits = lm_head.quant_method.apply(lm_head, hidden_states)
        if embedding_bias is not None:
            logits += embedding_bias

        # Gather logits for TP
        logits = self.base_layer._gather_logits(logits)

        if logits is None:
            return None

        if self.sharded_to_full_mapping_gpu is not None:
            # Reindex full logits tensor to ensure 1:1 mapping between
            # index and token_id
            # Example for:
            #   org_vocab_size = 4
            #   added_vocab_size = 2
            #   pad_to_size = 8
            #   tp_size = 2

            # indices:  [0, 1, 2,  3, 4, 5, 6,  7]
            # token_id: [0, 1, 4, -1, 2, 3, 5, -1]

            # Therefore, the mapping is expected to be:
            # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex,
            # we get:
            # indices:  [0, 1, 2, 3, 4, 5,  6,  7]
            # token_id: [0, 1, 2, 3, 4, 5, -1, -1]
            logits = logits[:, self.sharded_to_full_mapping_gpu]

        # Logits for the extra (LoRA-added) vocabulary, one slab per LoRA
        # slot plus one trailing all -inf slab.
        lora_logits = torch.empty(
            self.embeddings_tensors.shape[0] + 1,
            self.embeddings_tensors.shape[1],
            hidden_states.shape[0],
            dtype=self.embeddings_tensors.dtype,
            device=self.embeddings_tensors.device,
        )
        torch.matmul(self.embeddings_tensors,
                     hidden_states.T,
                     out=lora_logits[:-1])
        lora_logits[-1] = float("-inf")
        lora_logits = lora_logits.mT
        indices_padded = self.punica_wrapper.sampler_indices_padded
        # Flatten the first two dims, pick one row per entry of
        # indices_padded, then replace NaNs (e.g. from -inf embeddings)
        # with -inf.
        lora_logits = (lora_logits.reshape(
            lora_logits.shape[0] * lora_logits.shape[1],
            lora_logits.shape[2],
        ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"),
                                                      posinf=float("inf"),
                                                      neginf=float("-inf")))

        # HPU needs special handling to prune out dummy samples.
        if current_platform.is_hpu():
            lora_logits = lora_logits[:logits.shape[0], :]

        # Write the extra-vocab logits right after the original vocabulary.
        logits[:,
               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
               lora_logits.shape[1]] = lora_logits

        # LogitsProcessorWithLoRA always using bgmv
        self.punica_wrapper.add_lora_logits(logits, hidden_states,
                                            self.lora_a_stacked,
                                            self.lora_b_stacked, 1.0)

        # Remove paddings in vocab (if any).
        logits = logits[:, :self.base_layer.vocab_size]
        return logits

    def forward(self, *args, **kwargs):
        """Run the base layer class's ``forward`` with this wrapper as
        ``self``.

        Presumably this lets the base implementation pick up this class's
        ``_get_logits`` and forwarded properties.
        """
        return type(self.base_layer).forward(self, *args, **kwargs)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Always return False."""
        # Special handling for the LogitsProcessor.
        return False
1141
1142


1143
class LinearScalingRotaryEmbeddingWithLoRA(BaseLayerWithLoRA):
    """Linearly-scaled RoPE embeddings for multiple LoRA adapters, served
    by a specialized kernel.

    The wrapped rotary embedding is replaced with a
    LinearScalingRotaryEmbedding built over several scaling factors
    (MultiLinearScalingRotaryEmbedding), which can handle multiple LoRA
    adapters in a specialized kernel.
    """

    def __init__(self, base_layer: RotaryEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer

    @property
    def scaling_factors(self):
        return self.base_layer.scaling_factors

    @property
    def rotary_dim(self):
        return self.base_layer.rotary_dim

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Rebuild the base layer over every required scaling factor.

        The factor set is the base layer's own factor (1.0 when it is not
        a LinearScalingRotaryEmbedding) plus the configured long-LoRA
        factors, de-duplicated and sorted ascending.
        """
        if lora_config.long_lora_scaling_factors:
            extra_factors = list(lora_config.long_lora_scaling_factors)
        else:
            extra_factors = []

        # NOTE(review): reads `scaling_factor` (singular) here while the
        # `scaling_factors` property reads the plural attribute from the
        # same base layer; confirm the singular attribute exists.
        if isinstance(self.base_layer, LinearScalingRotaryEmbedding):
            own_factor = self.base_layer.scaling_factor
        else:
            own_factor = 1.0

        all_factors = sorted({own_factor, *extra_factors})

        old_layer = self.base_layer
        self.base_layer = LinearScalingRotaryEmbedding(
            old_layer.head_size,
            old_layer.rotary_dim,
            old_layer.max_position_embeddings,
            old_layer.base,
            old_layer.is_neox_style,
            all_factors,
            old_layer.dtype,
        )

    def reset_lora(self, index: int):
        """No-op: this layer keeps no per-adapter weights."""

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """No-op: all arguments are ignored."""

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Call the base layer, passing the punica wrapper's
        ``long_lora_indices`` as ``offsets``."""
        offsets = self.punica_wrapper.long_lora_indices
        return self.base_layer(positions, query, key, offsets=offsets)

    @property
    def scaling_factor_to_offset(self) -> Dict[float, int]:
        return self.base_layer.scaling_factor_to_offset

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        # Exact type match only; subclasses are not accepted.
        layer_type = type(source_layer)
        return (layer_type is LinearScalingRotaryEmbedding
                or layer_type is RotaryEmbedding)

    def extra_repr(self) -> str:
        return self.base_layer.extra_repr()