layers.py 42.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
# pylint: disable=unused-argument
import math
from dataclasses import dataclass
7
from typing import TYPE_CHECKING, Optional, Union, cast
8
9
10
11
12
13

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

14
from vllm.adapter_commons.layers import AdapterMapping
15
from vllm.config import LoRAConfig
16
17
18
19
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
20
                              tensor_model_parallel_all_reduce)
21
from vllm.distributed.utils import divide
22
# yapf: disable
23
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
24
                                               LinearBase,
25
                                               MergedColumnParallelLinear,
26
                                               QKVParallelLinear,
27
                                               ReplicatedLinear,
28
                                               RowParallelLinear)
29
# yapf: enable
30
from vllm.model_executor.layers.logits_processor import LogitsProcessor
31
from vllm.model_executor.layers.vocab_parallel_embedding import (
32
    VocabParallelEmbedding)
33
from vllm.platforms import current_platform
34
35

if TYPE_CHECKING:
36
    from vllm.lora.punica_wrapper import PunicaWrapperBase
37
38


39
40
41
def _get_lora_device(base_layer: nn.Module) -> torch.device:
    # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
    """Returns the device for where to place the LoRA tensors."""
Jee Li's avatar
Jee Li committed
42
    # unquantizedLinear
43
44
    if hasattr(base_layer, "weight"):
        return base_layer.weight.device
45
46
47
    # Compressed Tensor
    elif hasattr(base_layer, "weight_packed"):
        return base_layer.weight_packed.device
48
    # GPTQ/AWQ
Jee Li's avatar
Jee Li committed
49
50
    elif hasattr(base_layer, "qweight"):
        return base_layer.qweight.device
51
52
53
    # HQQ marlin
    elif hasattr(base_layer, "W_q"):
        return base_layer.W_q.device
Jee Li's avatar
Jee Li committed
54
55
    else:
        raise ValueError(f"Unsupported base layer: {base_layer}")
56
57


58
59
60
61
62
63
64
def _not_fully_sharded_can_replace(can_replace):
    """
    decorator which adds the condition of not using fully sharded loras
    intended to wrap can_replace_layer()
    """

    def dec(*args, **kwargs):
65
66
        decorate = kwargs.pop("decorate") if "decorate" in kwargs else True
        condition = (not kwargs["lora_config"].fully_sharded_loras
67
68
69
70
71
72
                     if decorate else True)
        return can_replace(*args, **kwargs) and condition

    return dec


73
@dataclass
class LoRAMapping(AdapterMapping):
    """Adapter mapping used by the LoRA layers.

    Extends ``AdapterMapping`` with one extra field.

    Attributes:
        is_prefill: boolean flag carried with the mapping (False unless set);
            presumably marks a prefill batch -- confirm against callers.
    """

    is_prefill: bool = False
76
77
78
79


class BaseLayerWithLoRA(nn.Module):
    """Interface implemented by every LoRA-wrapped layer in this module.

    Concrete wrappers override the hooks below. The default bodies of the
    slicing and weight-management hooks do nothing (they return None);
    ``can_replace_layer`` raises NotImplementedError until overridden.
    """

    def slice_lora_a(
        self, lora_a: Union[torch.Tensor, list[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]:
        """Slice lora a if splitting for tensor parallelism.

        Receives either one tensor or one optional tensor per slice.
        """
        ...

    def slice_lora_b(
        self, lora_b: Union[torch.Tensor, list[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]:
        """Slice lora b if splitting with tensor parallelism.

        Receives either one tensor or one optional tensor per slice.
        """
        ...

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Initializes lora matrices.

        Args:
            max_loras: number of LoRA slots to allocate buffers for.
            lora_config: the LoRA configuration.
            model_config: optional model configuration.
        """
        ...

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Overwrites lora tensors at index."""
        ...

    def set_mapping(
        self,
        punica_wrapper,
    ):
        """Store the punica wrapper that subclasses use in forward/apply."""
        self.punica_wrapper: PunicaWrapperBase = punica_wrapper

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        raise NotImplementedError

133
134
135
136
137
138

class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper around a ``VocabParallelEmbedding`` layer.

    Besides the stacked LoRA A/B buffers, the layer keeps one
    extra-vocabulary embedding table per LoRA slot (``embeddings_tensors``).
    When this partition owns added-vocab rows of the base weight, the
    matching part of those tables is mirrored into those rows.
    """

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        # Declared for type checkers only; create_lora_weights() assigns them.
        self.embeddings_slice: Optional[tuple[int, int]]
        self.embeddings_weights: Optional[torch.Tensor]

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        """Allocate per-slot LoRA buffers for up to ``max_loras`` adapters.

        Buffer shapes:
            embeddings_tensors: (max_loras, lora_extra_vocab_size,
                embedding_dim), in the base weight's dtype.
            lora_a_stacked: (max_loras, org_vocab_size +
                lora_extra_vocab_size, max_lora_rank).
            lora_b_stacked: (max_loras, 1, embedding_dim, max_lora_rank).
        All are zero-initialised on the base weight's device.
        """
        layer = self.base_layer
        base_weight = layer.weight

        if layer.num_added_embeddings_per_partition <= 0:
            # This partition holds no added-vocab rows.
            self.embeddings_slice = None
            self.embeddings_weights = None
        else:
            # We can start adding lora weights
            first_added_row = layer.num_org_embeddings_per_partition
            last_added_row = (first_added_row +
                              layer.num_added_embeddings_per_partition)
            # Basic slicing yields a view, so set_lora() writes straight into
            # the base layer's weight through this tensor.
            self.embeddings_weights = base_weight.data[
                first_added_row:last_added_row]
            vocab_offset = layer.org_vocab_size
            self.embeddings_slice = (
                layer.shard_indices.added_vocab_start_index - vocab_offset,
                layer.shard_indices.added_vocab_end_index - vocab_offset,
            )
            base_weight.data[first_added_row:].fill_(0)

        rank = lora_config.max_lora_rank
        extra_vocab = lora_config.lora_extra_vocab_size
        emb_dim = layer.embedding_dim
        target_device = base_weight.device

        self.embeddings_tensors = torch.zeros(
            (max_loras, extra_vocab, emb_dim),
            dtype=base_weight.dtype,
            device=target_device,
        )
        self.lora_a_stacked = torch.zeros(
            (max_loras, layer.org_vocab_size + extra_vocab, rank),
            dtype=lora_config.lora_dtype,
            device=target_device,
        )
        self.lora_b_stacked = torch.zeros(
            (max_loras, 1, emb_dim, rank),
            dtype=lora_config.lora_dtype,
            device=target_device,
        )
        # Flattened (slot * row, rank) view of lora_a_stacked; forward()
        # looks rows up in it with F.embedding.
        slots, rows_per_slot, a_rank = self.lora_a_stacked.shape
        self.lora_a_stacked_2d = self.lora_a_stacked.view(
            slots * rows_per_slot, a_rank)

    def reset_lora(self, index: int):
        """Zero the A, B and extra-vocab buffers of slot ``index``."""
        for buffer in (self.lora_a_stacked, self.lora_b_stacked,
                       self.embeddings_tensors):
            buffer[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Load one adapter into slot ``index``.

        ``lora_a`` is copied as-is, ``lora_b`` is stored transposed, and the
        optional ``embeddings_tensor`` goes into the slot's extra-vocab table
        (and, if this partition owns added-vocab rows, into the base weight).
        ``bias`` is accepted for interface compatibility and not used.
        """
        self.reset_lora(index)

        a_rows, a_cols = lora_a.shape[0], lora_a.shape[1]
        self.lora_a_stacked[index, :a_rows, :a_cols].copy_(lora_a,
                                                           non_blocking=True)
        b_rows, b_cols = lora_b.shape[0], lora_b.shape[1]
        self.lora_b_stacked[index, 0, :b_cols, :b_rows].copy_(
            lora_b.T, non_blocking=True)

        if embeddings_tensor is None:
            return

        emb_rows = embeddings_tensor.shape[0]
        emb_cols = embeddings_tensor.shape[1]
        self.embeddings_tensors[index, :emb_rows, :emb_cols].copy_(
            embeddings_tensor, non_blocking=True)

        if self.embeddings_slice is None:
            return

        # TODO(yard1): Optimize this copy, we don't need to copy
        # everything, just the modified part
        slice_start, slice_end = self.embeddings_slice
        all_tables = self.embeddings_tensors
        flat_tables = all_tables.view(
            all_tables.shape[0] * all_tables.shape[1],
            all_tables.shape[2],
        )[slice_start:slice_end]
        assert self.embeddings_weights is not None
        self.embeddings_weights[:flat_tables.shape[0]].copy_(flat_tables)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` with the base layer and add the LoRA contribution."""
        layer = self.base_layer
        # 1 for token ids beyond the original vocabulary, 0 otherwise.
        added_tokens_mask = torch.where(x > layer.org_vocab_size - 1, 1, 0)

        # NB: Don't use torch.narrow here. torch.narrow triggers some
        # Dynamic Shape specialization in torch.compile
        num_tokens = x.shape[0]
        lora_offsets = self.punica_wrapper._embeddings_indices[1][:num_tokens]
        base_offsets = self.punica_wrapper._embeddings_indices[0][:num_tokens]

        lora_a_rows = F.embedding(
            x + lora_offsets,
            self.lora_a_stacked_2d,
        )
        # Only added-vocab tokens are shifted before the base lookup.
        base_out = layer.forward(x + (base_offsets * added_tokens_mask))

        shape_reference = base_out
        if base_out.ndim == 3:
            base_out = base_out.view(base_out.shape[0] * base_out.shape[1],
                                     -1)
        if lora_a_rows.ndim == 3:
            lora_a_rows = lora_a_rows.view(
                lora_a_rows.shape[0] * lora_a_rows.shape[1],
                -1,
            )

        lora_output: Optional[torch.Tensor] = (
            self.punica_wrapper.add_lora_embedding(base_out,
                                                   lora_a_rows,
                                                   self.lora_b_stacked,
                                                   add_input=True))

        if not current_platform.can_update_inplace():
            base_out = lora_output

        return base_out.view_as(shape_reference)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Only an exact ``VocabParallelEmbedding`` (not a subclass)."""
        return type(source_layer) is VocabParallelEmbedding

    @property
    def weight(self):
        """The wrapped embedding layer's weight."""
        return self.base_layer.weight

289

290
class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
    """Shared LoRA logic for the linear-layer wrappers in this module.

    Subclasses must assign ``tp_size``, ``output_size`` and ``n_slices`` in
    their ``__init__``; ``create_lora_weights`` reads all three.
    """

    def __init__(self, base_layer: LinearBase):
        super().__init__()
        self.base_layer = base_layer
        self.input_size = self.base_layer.input_size
        self.device = _get_lora_device(self.base_layer)
        # Stays None unless create_lora_weights() runs with bias enabled.
        self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None

        # Declared for type checkers; assigned by subclasses or by
        # create_lora_weights().
        self.output_slices: tuple[int, ...]
        self.tp_size: int
        self.output_size: int
        self.n_slices: int

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate zero-initialised LoRA buffers, one per slice.

        Shapes per slice:
            A:    (max_loras, 1, lora_a_out_size, input_size)
            B:    (max_loras, 1, lora_b_out_size, max_lora_rank)
            bias: (max_loras, 1, lora_b_out_size)  -- only if bias_enabled

        Raises:
            NotImplementedError: if the base layer is not a Replicated,
                ColumnParallel or RowParallel linear layer.
        """
        self.lora_config = lora_config
        # With fully sharded LoRAs, column-parallel layers shard the rank
        # dimension of A and row-parallel layers shard the output of B.
        # ReplicatedLinear is checked first; ColumnParallelLinear also
        # covers its subclasses.
        if isinstance(self.base_layer, ReplicatedLinear):
            lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = self.output_size

        elif isinstance(self.base_layer, ColumnParallelLinear):
            lora_a_out_size = (lora_config.max_lora_rank if
                               not lora_config.fully_sharded_loras else divide(
                                   lora_config.max_lora_rank, self.tp_size))
            lora_b_out_size = self.output_size

        elif isinstance(self.base_layer, RowParallelLinear):
            lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = (self.output_size if
                               not lora_config.fully_sharded_loras else divide(
                                   self.output_size, self.tp_size))
        else:
            raise NotImplementedError

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_out_size,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(self.n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_b_out_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(self.n_slices))
        if lora_config.bias_enabled:
            self.lora_bias_stacked = tuple(
                torch.zeros(
                    max_loras,
                    1,
                    lora_b_out_size,
                    dtype=lora_config.lora_dtype,
                    device=self.device,
                ) for _ in range(self.n_slices))
        self.output_slices = (self.lora_b_stacked[0].shape[2], )

    def reset_lora(self, index: int):
        """Zero slot ``index`` in every slice's A, B (and bias) buffers."""
        for s_index in range(self.n_slices):
            self.lora_a_stacked[s_index][index] = 0
            self.lora_b_stacked[s_index][index] = 0
            if self.lora_config.bias_enabled:
                # Make mypy happy
                self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
                                              self.lora_bias_stacked)
                self.lora_bias_stacked[s_index][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        lora_bias: Optional[torch.Tensor] = None,
    ):
        """Write one adapter's weights into slot ``index``.

        When ``tp_size > 1`` the inputs are first passed through
        ``slice_lora_a`` / ``slice_lora_b`` / ``slice_bias``. A and B are
        stored transposed relative to the incoming tensors.
        ``embeddings_tensor`` is accepted for interface compatibility and is
        not used here.
        """
        # Except for QKVParallelLinearWithLoRA and
        # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
        # store weights in a tuple of size 1. These two layers will
        # override this function.
        assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) ==
                self.n_slices == 1)

        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
            if lora_bias is not None:
                lora_bias = self.slice_bias(lora_bias)

        self.lora_a_stacked[0][index,
                               0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                   lora_a.T, non_blocking=True)
        self.lora_b_stacked[0][index,
                               0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                   lora_b.T, non_blocking=True)
        if lora_bias is not None:
            self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
                                          self.lora_bias_stacked)
            assert len(self.lora_bias_stacked)
            # The bias fills a 1-D slice, so it is copied directly. ``.T``
            # would be a no-op on a 1-D tensor, and PyTorch emits a
            # deprecation warning for ``Tensor.T`` on non-2-D tensors.
            self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_(
                lora_bias, non_blocking=True)

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the base layer's quant method, then add the LoRA delta.

        The returned tensor has the same shape as the base layer's output.
        """
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        # In transformers backend, x and output have extra batch dimension like
        # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
        # therefore we need to flatten the batch dimensions. The batched shape
        # is remembered so the result can be given back in the same shape the
        # base layer produced.
        batched_shape: Optional[torch.Size] = None
        if x.ndim == 3 and output.ndim == 3:
            batched_shape = output.shape
            output = output.flatten(0, 1)
            x = x.flatten(0, 1)

        lora_output: Optional[
            torch.Tensor] = self.punica_wrapper.add_lora_linear(
                output, x, self.lora_a_stacked, self.lora_b_stacked,
                self.lora_bias_stacked, 1.0, self.output_slices)
        if not current_platform.can_update_inplace():
            output = lora_output

        if batched_shape is not None:
            output = output.reshape(batched_shape)
        return output

    @property
    def weight(self) -> torch.Tensor:
        """The base layer's (possibly quantized) weight tensor.

        Raises:
            ValueError: if no known weight attribute is present.
        """
        # unquantizedLinear
        if hasattr(self.base_layer, "weight"):
            return self.base_layer.weight
        # Compressed Tensor
        elif hasattr(self.base_layer, "weight_packed"):
            return self.base_layer.weight_packed
        # GPTQ/AWQ
        elif hasattr(self.base_layer, "qweight"):
            return self.base_layer.qweight
        # marlin
        elif hasattr(self.base_layer, "B"):
            return self.base_layer.B
        # HQQ marlin
        elif hasattr(self.base_layer, "W_q"):
            return self.base_layer.W_q
        else:
            raise ValueError(f"Unsupported base layer: {self.base_layer}")

    @property
    def bias(self) -> Optional[torch.Tensor]:
        """The base layer's ``bias`` attribute, or None if it has none."""
        return getattr(self.base_layer, "bias", None)

455
456
457
458
459
460
461
462
463
464

class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
    """LoRA wrapper for ``ReplicatedLinear``.

    The base layer is copied per GPU, so LoRA weights are never sliced for
    tensor parallelism and there is exactly one slice.
    """

    def __init__(self, base_layer: ReplicatedLinear) -> None:
        super().__init__(base_layer)
        # To ensure interface compatibility, set to 1 always.
        self.tp_size = 1
        self.output_size = self.base_layer.output_size
        self.n_slices = 1

    def forward(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Forward of ReplicatedLinearWithLoRA

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias (only when the base layer has ``return_bias`` set; it is
              the layer bias if ``skip_bias_add`` is set, else None)
        """
        defer_bias = self.base_layer.skip_bias_add
        layer_bias = self.base_layer.bias

        # Matrix multiply; the bias is fused in unless it is deferred.
        output = self.apply(input_, None if defer_bias else layer_bias)

        if not self.base_layer.return_bias:
            return output
        return output, (layer_bias if defer_bias else None)

    # ReplicatedLinear should always be replaced, regardless of the fully
    # sharded LoRAs setting, because it is, by definition, copied per GPU.
    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Only an exact ``ReplicatedLinear`` (not a subclass)."""
        return type(source_layer) is ReplicatedLinear
502
503


504
class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
    """
    LoRA on top of ColumnParallelLinear layer.
    LoRA B (and the LoRA bias) is sliced for tensor parallelism.

    There are two types for the `base_layer`:
    1. ColumnParallelLinear, e.g.`dense_h_to_4h` in `FalconForCausalLM`.
    2. MergedColumnParallelLinear, e.g.`gate_up_proj` in `Phi3ForCausalLM`.
    """

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__(base_layer)
        # The base_layer type is ColumnParallelLinear or
        # MergedColumnParallelLinear, their weight sharding logic is
        # inconsistent when TP is greater than 1.
        self.is_merged_col_linear = type(
            base_layer) is MergedColumnParallelLinear
        self.tp_size = get_tensor_model_parallel_world_size()
        self.output_size = self.base_layer.output_size_per_partition
        # There is only one LoRA layer
        self.n_slices = 1

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """LoRA A is not sharded for column-parallel layers."""
        return lora_a

    def _shard_output_dim(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return this rank's shard of ``tensor`` along its last dimension.

        ColumnParallelLinear: one contiguous block of ``output_size``
        entries, starting at ``tp_rank * output_size``.
        MergedColumnParallelLinear: the full last dimension is split into two
        halves; this rank's ``output_size // 2`` entries are taken from each
        half and concatenated.
        """
        tp_rank = get_tensor_model_parallel_rank()
        if self.is_merged_col_linear:
            shard_size = self.output_size // 2
            offset = tensor.shape[-1] // 2
            left = tensor[..., tp_rank * shard_size:(tp_rank + 1) *
                          shard_size]
            right = tensor[..., offset + tp_rank * shard_size:offset +
                           (tp_rank + 1) * shard_size]
            return torch.cat([left, right], dim=-1)
        start_idx = tp_rank * self.output_size
        end_idx = (tp_rank + 1) * self.output_size
        return tensor[..., start_idx:end_idx]

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Shard LoRA B along its output (last) dimension for this rank."""
        return self._shard_output_dim(lora_b)

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        """Shard the LoRA bias exactly like LoRA B; None passes through.

        For a MergedColumnParallelLinear base layer, the bias is sharded per
        half, matching LoRA B. A single contiguous chunk would mix the two
        halves.
        """
        if bias is None:
            return bias
        return self._shard_output_dim(bias)

    def forward(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Forward of ColumnParallelLinear

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output (all-gathered across ranks if ``gather_output``)
            - bias (only when the base layer has ``return_bias`` set)
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output_parallel = self.apply(input_, bias)
        if self.base_layer.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel

        if not self.base_layer.return_bias:
            return output

        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Exact ColumnParallelLinear, or a merged one with a single LoRA."""
        return type(source_layer) is ColumnParallelLinear or (
            type(source_layer) is MergedColumnParallelLinear
            and len(packed_modules_list) == 1)

605
606
607
608
609
610
611
612
613
614

class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer composed of several sublayers (slices)
    packed together (eg. gate_proj + up_proj -> gate_up_proj).

    There is one LoRA per slice, each applied to its own part of the
    layer's output. Per-slice sizes come from ``base_layer.output_sizes``
    divided by the TP size, so slices need not be the same size.
    """

    def __init__(
        self, base_layer: Union[MergedColumnParallelLinear,
                                QKVParallelLinear]) -> None:
        super().__init__(base_layer)
        # One LoRA layer per entry of output_sizes
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        # the output_sizes in MergedColumnParallelLinear is not sharded by tp
        # we need to divide it by the tp_size to get correct slices size
        output_sizes = self.base_layer.output_sizes
        self.output_slices = tuple(
            divide(output_size, self.tp_size) for output_size in output_sizes)
        self.n_slices = len(self.output_slices)
        self.output_ids = (self.tp_rank, ) * self.n_slices

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate per-slice LoRA buffers.

        The main reason for overriding this function is to enhance code
        maintainability: B and bias buffers are sized per slice from
        ``output_slices`` instead of one shared output size.
        """
        self.lora_config = lora_config

        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(self.n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                output_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for output_size in self.output_slices)
        if lora_config.bias_enabled:
            self.lora_bias_stacked = tuple(
                torch.zeros(
                    max_loras,
                    1,
                    output_size,
                    dtype=lora_config.lora_dtype,
                    device=self.device,
                ) for output_size in self.output_slices)

    def slice_lora_a(
        self, lora_a: list[Union[torch.Tensor, None]]
    ) -> list[Union[torch.Tensor, None]]:
        """LoRA A is not sharded for column-parallel layers."""
        return lora_a

    def slice_lora_b(
        self, lora_b: list[Union[torch.Tensor, None]]
    ) -> list[Union[torch.Tensor, None]]:
        """Return a new list with this rank's column shard of each LoRA B."""
        sliced_lora_b: list[Union[torch.Tensor, None]] = [None] * self.n_slices
        for i, (shard_id, shard_size) in enumerate(
                zip(self.output_ids, self.output_slices)):
            if (lora_b_i := lora_b[i]) is not None:
                sliced_lora_b[i] = lora_b_i[:,
                                            shard_size * shard_id:shard_size *
                                            (shard_id + 1)]
        return sliced_lora_b

    def slice_bias(
        self, bias: list[Union[torch.Tensor,
                               None]]) -> list[Union[torch.Tensor, None]]:
        """Return a new list with this rank's shard of each per-slice bias.

        The input list is left untouched. Slicing it in place would corrupt
        the caller's stored adapter bias, so setting the same adapter again
        would re-slice an already-sliced tensor.
        """
        sliced_bias: list[Union[torch.Tensor, None]] = [None] * self.n_slices
        for i, (shard_id, shard_size) in enumerate(
                zip(self.output_ids, self.output_slices)):
            if (bias_i := bias[i]) is not None:
                sliced_bias[i] = bias_i[shard_size * shard_id:shard_size *
                                        (shard_id + 1)]
        return sliced_bias

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        lora_bias: Optional[torch.Tensor] = None,
    ):
        """Write one adapter's per-slice weights into slot ``index``.

        Despite the annotations, ``lora_a``, ``lora_b`` and ``lora_bias`` are
        indexed per slice here (entries may be None, which leaves that slice
        zeroed). ``embeddings_tensor`` is not used.
        """
        self.reset_lora(index)

        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
            if lora_bias is not None:
                lora_bias = self.slice_bias(lora_bias)

        for i in range(self.n_slices):
            if (lora_a_i := lora_a[i]) is not None:
                self.lora_a_stacked[i][
                    index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_(
                        lora_a_i.T, non_blocking=True)
            if (lora_b_i := lora_b[i]) is not None:
                self.lora_b_stacked[i][
                    index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_(
                        lora_b_i.T, non_blocking=True)

        if lora_bias is not None:
            self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
                                          self.lora_bias_stacked)
            for i in range(self.n_slices):
                if (lora_bias_i := lora_bias[i]) is not None:
                    # Each per-slice bias fills a 1-D slice, so copy it
                    # directly; ``.T`` is a no-op there and PyTorch warns
                    # about ``Tensor.T`` on non-2-D tensors.
                    self.lora_bias_stacked[i][index,
                                              0, :lora_bias_i.shape[0]].copy_(
                                                  lora_bias_i,
                                                  non_blocking=True)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Exact MergedColumnParallelLinear packing exactly two modules."""
        return (type(source_layer) is MergedColumnParallelLinear
                and len(packed_modules_list) == 2)
748

749

750
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """
    ColumnParallelLinear layer that is specifically designed for
    qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
    only contain a single LoRA within their qkv_proj layer.

    During inference with Tensor Parallel, the weights of lora_b
    must be accurately partitioned according to the respective ranks.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        self.q_proj_total_size = (self.base_layer.total_num_heads *
                                  self.base_layer.head_size)
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
                                   self.base_layer.head_size)
        # The shard ids depend only on this rank, so compute them once here.
        # Previously they were assigned inside slice_lora_b, which made
        # slice_bias fail with AttributeError if called before it.
        tp_rank = get_tensor_model_parallel_rank()
        self.q_shard_id = tp_rank
        # K/V heads can be replicated across ranks; every group of
        # `num_kv_head_replicas` consecutive ranks shares one KV shard.
        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas

        # There is only one LoRA layer
        self.n_slices = 1

    def _slice_qkv(self, tensor: torch.Tensor) -> torch.Tensor:
        """Select this rank's Q, K and V pieces along the last dimension.

        The last dimension of ``tensor`` is laid out as
        ``[Q (q_proj_total_size) | K (kv_proj_total_size) |
        V (kv_proj_total_size)]``. The three per-rank pieces are
        concatenated back together along that same (last) dimension.
        Working on the last dimension lets the 2-D ``lora_b`` and the 1-D
        bias share this code.
        """
        q_start = self.q_proj_shard_size * self.q_shard_id
        k_start = self.q_proj_total_size + (self.kv_proj_shard_size *
                                            self.kv_shard_id)
        v_start = k_start + self.kv_proj_total_size
        q = tensor[..., q_start:q_start + self.q_proj_shard_size]
        k = tensor[..., k_start:k_start + self.kv_proj_shard_size]
        v = tensor[..., v_start:v_start + self.kv_proj_shard_size]
        return torch.cat([q, k, v], dim=-1)

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Return this rank's Q/K/V columns of the 2-D ``lora_b``."""
        return self._slice_qkv(lora_b)

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        """Return this rank's Q/K/V entries of the 1-D LoRA bias.

        The original code concatenated the 1-D slices on ``dim=1``, which
        always raised; the shared helper concatenates on the last dim.
        """
        return self._slice_qkv(bias)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: list,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Match an exact QKVParallelLinear that packs a single module."""
        return type(source_layer) is QKVParallelLinear and len(
            packed_modules_list) == 1


818
class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
    """MergedColumnParallelLinear layer that is composed of 3 sublayers (slices)
    packed together in qkv proj fashion
    (q_proj + k_proj + v_proj -> qkv_proj).

    This means we have 3 LoRAs, each applied to one slice of the layer.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        layer = self.base_layer

        # One LoRA slice per entry of the base layer's output_sizes.
        self.n_slices = len(layer.output_sizes)
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        head_size = layer.head_size
        self.q_proj_shard_size = layer.num_heads * head_size
        self.kv_proj_shard_size = layer.num_kv_heads * head_size

        # Ranks that share a replicated K/V head map onto the same KV shard.
        self.q_shard_id = self.tp_rank
        self.kv_shard_id = self.tp_rank // layer.num_kv_head_replicas

        # Per-slice output widths and shard ids, in Q, K, V order.
        q_size, kv_size = self.q_proj_shard_size, self.kv_proj_shard_size
        self.output_slices = (q_size, kv_size, kv_size)
        q_id, kv_id = self.q_shard_id, self.kv_shard_id
        self.output_ids = (q_id, kv_id, kv_id)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate the LoRA buffers by delegating to the parent class.

        This override adds no logic of its own. NOTE(review): the Q vs. K/V
        width difference is presumably handled by the parent through
        ``self.output_slices`` set in ``__init__`` — confirm against
        MergedColumnParallelLinearWithLoRA.create_lora_weights.
        """
        super().create_lora_weights(max_loras, lora_config, model_config)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Match an exact QKVParallelLinear built from three packed modules."""
        is_qkv = type(source_layer) is QKVParallelLinear
        return is_qkv and len(packed_modules_list) == 3
877

878

879
880
881
882
883
# TODO: Implement this
class QKVCrossParallelLinearWithLoRA(BaseLayerWithLoRA):
    """Placeholder LoRA wrapper that is not implemented yet.

    It currently adds nothing on top of BaseLayerWithLoRA.
    """


884
class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
    """LoRA wrapper for RowParallelLinear.

    The input features are split across tensor-parallel ranks. Only the rows
    of lora_a that correspond to this rank's input shard are kept; lora_b and
    the LoRA bias are used as-is.
    """

    def __init__(self, base_layer: RowParallelLinear) -> None:
        super().__init__(base_layer)

        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        # This wrapper only sees its own partition of the input features.
        self.input_size = self.base_layer.input_size_per_partition
        self.output_size = self.base_layer.output_size
        # There is only one LoRA layer.
        self.n_slices = 1

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """Keep the rows of lora_a that belong to this rank's input shard."""
        rows = self.input_size
        offset = self.tp_rank * rows
        return lora_a[offset:offset + rows, :]

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """lora_b is not sharded for row-parallel layers."""
        return lora_b

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        """The LoRA bias is not sharded for row-parallel layers."""
        return bias

    def forward(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Run the row-parallel projection with LoRA applied.

        Args:
            input_: tensor whose last dimension is `input_size`. If the base
                    layer's `input_is_parallel` is set, the last dimension
                    is already `input_size // tp_size`.

        Returns:
            `output` alone when the base layer's `return_bias` is false,
            otherwise `(output, output_bias)`; `output_bias` is the base
            bias when `skip_bias_add` is set and None otherwise.
        """
        base = self.base_layer

        if base.input_is_parallel:
            input_parallel = input_
        else:
            # Take only this rank's chunk of the full input.
            chunks = split_tensor_along_last_dim(
                input_, num_partitions=base.tp_size)
            input_parallel = chunks[self.tp_rank].contiguous()

        # Matrix multiply on this rank's shard.
        output_parallel = self.apply(input_parallel)

        # Combine the per-rank partial outputs when requested.
        needs_reduce = base.reduce_results and base.tp_size > 1
        output_ = (tensor_model_parallel_all_reduce(output_parallel)
                   if needs_reduce else output_parallel)

        bias = base.bias
        if base.skip_bias_add:
            output, output_bias = output_, bias
        else:
            output = output_ if bias is None else output_ + bias
            output_bias = None

        return (output, output_bias) if base.return_bias else output

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Only an exact RowParallelLinear (not a subclass) is replaced."""
        return type(source_layer) is RowParallelLinear
965

966

967
class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
    """
    LoRA wrapper for LogitsProcessor, with extra logic to handle the
    application of the LoRA adapter and added LoRA vocabulary.

    Args:
        base_layer: LogitsProcessor layer
        hidden_size: hidden size of the model
        dtype: data type of the model
        device: device of the model
        sharded_to_full_mapping: index mapping from sharded vocab to full vocab
            received from base_layer.get_sharded_to_full_mapping(). If None,
            no reindexing will be done.
    """

    def __init__(self, base_layer: LogitsProcessor, hidden_size: int,
                 dtype: torch.dtype, device: torch.device,
                 sharded_to_full_mapping: Optional[list[int]]) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.sharded_to_full_mapping = sharded_to_full_mapping

    # The properties below forward attribute reads to the wrapped
    # LogitsProcessor. forward() runs the base class's forward with this
    # wrapper as `self`, so presumably that code reads these attributes.
    @property
    def logits_as_input(self):
        return self.base_layer.logits_as_input

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def scale(self):
        return self.base_layer.scale

    @property
    def soft_cap(self):
        return self.base_layer.soft_cap

    @property
    def use_all_gather(self):
        return self.base_layer.use_all_gather

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    @property
    def should_modify_greedy_probs_inplace(self):
        return self.base_layer.should_modify_greedy_probs_inplace

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate the per-adapter buffers for `max_loras` LoRA slots.

        Allocates zero-filled LoRA A/B stacks and -inf-filled extra-vocab
        embedding tables. Also uploads `sharded_to_full_mapping`, if any, as
        a long tensor on `self.device`.

        Raises:
            ValueError: if the base layer's vocab size exceeds 257024.
        """
        # TODO: Verify if this condition can be further relaxed
        # Only an upper bound is enforced. The previous chained comparison
        # `32000 < vocab_size > 257024` reduced to exactly this check, but
        # its error message claimed a lower bound that never applied.
        vocab_size = self.base_layer.vocab_size
        if vocab_size > 257024:
            raise ValueError("When using LoRA, vocab size must be "
                             f"<= 257024, got {vocab_size}")
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                # Pad for kernel compatibility
                math.ceil(self.base_layer.vocab_size /
                          lora_config.lora_vocab_padding_size) *
                lora_config.lora_vocab_padding_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        # -inf marks "no extra token here", so unused slots never win sampling.
        self.embeddings_tensors = torch.full(
            (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
            fill_value=float("-inf"),
            dtype=self.dtype,
            device=self.device,
        )
        if self.sharded_to_full_mapping is not None:
            self.sharded_to_full_mapping_gpu = torch.tensor(
                self.sharded_to_full_mapping,
                device=self.device,
                dtype=torch.long)
        else:
            self.sharded_to_full_mapping_gpu = None

    def reset_lora(self, index: int):
        """Clear adapter slot `index` back to its freshly-allocated state."""
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = float("-inf")

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Load one adapter into slot `index`.

        `lora_a` and `lora_b` are stored transposed into the top-left corner
        of their zero-padded slots. `embeddings_tensor` (extra-vocab rows),
        when given, is copied into the slot's embedding table. `bias` is
        accepted for interface compatibility but is not used here.
        """
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index,
                :embeddings_tensor.shape[0],
                :embeddings_tensor.shape[1],
            ] = embeddings_tensor

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: VocabParallelEmbedding,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        """Compute base logits, then add LoRA extra-vocab logits and the
        LoRA delta.

        Returns None when the TP gather yields no logits on this rank.
        """
        # Get the logits for the next tokens.
        logits = lm_head.quant_method.apply(lm_head, hidden_states)
        if embedding_bias is not None:
            logits += embedding_bias

        # Gather logits for TP
        logits = self.base_layer._gather_logits(logits)

        if logits is None:
            return None

        if self.sharded_to_full_mapping_gpu is not None:
            # Reindex full logits tensor to ensure 1:1 mapping between
            # index and token_id
            # Example for:
            #   org_vocab_size = 4
            #   added_vocab_size = 2
            #   pad_to_size = 8
            #   tp_size = 2

            # indices:  [0, 1, 2,  3, 4, 5, 6,  7]
            # token_id: [0, 1, 4, -1, 2, 3, 5, -1]

            # Since new[:, i] = old[:, mapping[i]], mapping[i] must be the
            # old index holding token i (padding goes last). The mapping
            # is therefore expected to be:
            # [0, 1, 4, 5, 2, 6, 3, 7] so that when we reindex,
            # we get:
            # indices:  [0, 1, 2, 3, 4, 5,  6,  7]
            # token_id: [0, 1, 2, 3, 4, 5, -1, -1]
            logits = logits[:, self.sharded_to_full_mapping_gpu]

        # One extra-vocab logit table per adapter slot, plus a trailing
        # all -inf table (presumably selected for tokens without an active
        # LoRA).
        lora_logits = torch.empty(
            self.embeddings_tensors.shape[0] + 1,
            self.embeddings_tensors.shape[1],
            hidden_states.shape[0],
            dtype=self.embeddings_tensors.dtype,
            device=self.embeddings_tensors.device,
        )
        torch.matmul(self.embeddings_tensors,
                     hidden_states.T,
                     out=lora_logits[:-1])

        neg_inf, pos_inf = current_platform.get_infinity_values(
            lora_logits.dtype)

        lora_logits[-1] = neg_inf
        lora_logits = lora_logits.mT
        indices_padded = self.punica_wrapper.sampler_indices_padded

        if current_platform.is_tpu():
            indices_padded = indices_padded[:logits.size(0)]

        # Flatten (slot, row) and pick, for every logits row, the row from
        # the slot given by indices_padded; NaN/inf are mapped to finite
        # platform limits.
        lora_logits = (lora_logits.reshape(
            lora_logits.shape[0] * lora_logits.shape[1],
            lora_logits.shape[2],
        ).index_select(0, indices_padded).nan_to_num_(nan=neg_inf,
                                                      posinf=pos_inf,
                                                      neginf=neg_inf))

        # Extra-vocab logits occupy the columns right after the original
        # vocabulary.
        logits[:,
               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
               lora_logits.shape[1]] = lora_logits

        lora_output: Optional[
            torch.Tensor] = self.punica_wrapper.add_lora_logits(
                logits, hidden_states, self.lora_a_stacked,
                self.lora_b_stacked, 1.0)

        # Platforms that cannot update in place return a new tensor instead.
        if not current_platform.can_update_inplace():
            logits = lora_output

        # Remove paddings in vocab (if any).
        logits = logits[:, :self.base_layer.vocab_size]
        return logits

    def forward(self, *args, **kwargs):
        # Reuse the base class's forward, bound to this wrapper so that its
        # call to _get_logits resolves to the LoRA-aware override above.
        return type(self.base_layer).forward(self, *args, **kwargs)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # Special handling for the LogitsProcessor.
        return False