layers.py 42.8 KB
Newer Older
1
2
3
# pylint: disable=unused-argument
import math
from dataclasses import dataclass
4
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
5
6
7
8
9
10

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

11
from vllm.adapter_commons.layers import AdapterMapping
12
from vllm.config import LoRAConfig
13
14
15
16
17
18
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce,
                              tensor_model_parallel_gather)
19
from vllm.distributed.utils import divide
20
from vllm.lora.punica import PunicaWrapper
21
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
22
                                               MergedColumnParallelLinear,
23
                                               QKVParallelLinear,
24
                                               RowParallelLinear)
25
from vllm.model_executor.layers.logits_processor import LogitsProcessor
26
27
from vllm.model_executor.layers.rotary_embedding import (
    LinearScalingRotaryEmbedding, RotaryEmbedding)
28
from vllm.model_executor.layers.vocab_parallel_embedding import (
29
    VocabParallelEmbedding)
30
31
32
33
34

if TYPE_CHECKING:
    pass


35
36
37
def _get_lora_device(base_layer: nn.Module) -> torch.device:
    """Return the device on which LoRA tensors for ``base_layer`` belong.

    The device is read from the first weight attribute the layer exposes,
    probed in this order:

    * ``weight``  -- unquantized linear layers
    * ``qweight`` -- GPTQ / AWQ / SqueezeLLM quantized layers
    * ``B``       -- Marlin quantized layers

    Code borrowed from
    https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34

    Raises:
        ValueError: if the layer has none of the attributes above.
    """
    # Quantization methods store their weight under different attribute
    # names, so probe them in priority order and take the first hit.
    for attr_name in ("weight", "qweight", "B"):
        if hasattr(base_layer, attr_name):
            return getattr(base_layer, attr_name).device
    raise ValueError(f"Unsupported base layer: {base_layer}")
49
50


51
52
53
54
55
56
57
def _not_fully_sharded_can_replace(can_replace):
    """
    decorator which adds the condition of not using fully sharded loras
    intended to wrap can_replace_layer()

    The wrapped callable must receive ``lora_config`` as a keyword argument.
    Callers may pass the extra keyword ``decorate=False`` to skip the added
    condition; ``decorate`` is consumed here and never forwarded to
    ``can_replace``.
    """
    import functools

    # functools.wraps keeps the wrapped can_replace_layer's name/docstring.
    @functools.wraps(can_replace)
    def dec(*args, **kwargs):
        decorate = kwargs.pop("decorate", True)
        # lora_config is only read when the extra condition is enabled.
        condition = (not kwargs["lora_config"].fully_sharded_loras
                     if decorate else True)
        return can_replace(*args, **kwargs) and condition

    return dec


66
@dataclass
class LoRAMapping(AdapterMapping):
    """LoRA-specific adapter mapping carrying an extra ``is_prefill`` flag."""

    # NOTE(review): presumably True when the mapping describes a prefill
    # batch rather than decode; confirm against the code that builds it.
    is_prefill: bool = False
69
70
71
72


class BaseLayerWithLoRA(nn.Module):
    """Common interface for layers that apply LoRA on top of a base layer.

    The hooks here are placeholders: the ``...`` bodies return ``None`` and
    ``can_replace_layer`` raises ``NotImplementedError``. Concrete LoRA
    layers override them.
    """

    def slice_lora_a(
        self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
        """Slice lora a if splitting for tensor parallelism."""
        ...

    def slice_lora_b(
        self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
        """Slice lora b if splitting with tensor parallelism."""
        ...

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Initializes lora matrices."""
        ...

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Overwrites lora tensors at index."""
        ...

    def set_mapping(
        self,
        punica_wrapper: PunicaWrapper,
    ):
        """Store the PunicaWrapper used to apply the LoRA kernels."""
        self.punica_wrapper: PunicaWrapper = punica_wrapper

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        raise NotImplementedError

125
126
127
128
129
130

class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper around a ``VocabParallelEmbedding`` layer.

    LoRA A is stored as one embedding table per adapter and LoRA B is
    applied with an expand-only punica op. Adapters may also carry
    extra-vocabulary embeddings; when this partition owns added-vocabulary
    rows, those rows of the base layer's weight are overwritten with the
    adapter's values.
    """

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        # Both are assigned in create_lora_weights(); they are None when
        # this partition holds no added-vocabulary rows.
        self.embeddings_slice: Optional[Tuple[int, int]]
        self.embeddings_weights: Optional[torch.Tensor]

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        """Allocate zeroed LoRA buffers for up to ``max_loras`` adapters.

        When this partition holds added-vocabulary rows, every base-weight
        row from ``num_org_embeddings_per_partition`` onward is also zeroed.
        """
        if self.base_layer.num_added_embeddings_per_partition > 0:
            # We can start adding lora weights
            self.embeddings_weights = self.base_layer.weight.data[
                self.base_layer.num_org_embeddings_per_partition:self.
                base_layer.num_org_embeddings_per_partition +
                self.base_layer.num_added_embeddings_per_partition]
            self.embeddings_slice = (
                self.base_layer.shard_indices.added_vocab_start_index -
                self.base_layer.org_vocab_size,
                self.base_layer.shard_indices.added_vocab_end_index -
                self.base_layer.org_vocab_size)
            self.base_layer.weight.data[
                self.base_layer.num_org_embeddings_per_partition:].fill_(0)
        else:
            self.embeddings_slice = None
            self.embeddings_weights = None

        self.embeddings_tensors = torch.zeros(
            (
                max_loras,
                lora_config.lora_extra_vocab_size,
                self.base_layer.embedding_dim,
            ),
            dtype=self.base_layer.weight.dtype,
            device=self.base_layer.weight.device,
        )
        # One A table per adapter, covering original plus extra vocab.
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                self.base_layer.org_vocab_size +
                lora_config.lora_extra_vocab_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.embedding_dim,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        # All adapters' A tables flattened into one (rows, rank) view so a
        # single F.embedding call can look up rows of any adapter.
        self.lora_a_stacked_2d = self.lora_a_stacked.view(
            self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
            self.lora_a_stacked.shape[2],
        )

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Load adapter ``index``; B is stored transposed in its buffer."""
        self.reset_lora(index)
        self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
            lora_a, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.shape[1]
            ].copy_(embeddings_tensor, non_blocking=True)
            if self.embeddings_slice is not None:
                # TODO(yard1): Optimize this copy, we don't need to copy
                # everything, just the modified part
                embeddings = self.embeddings_tensors.view(
                    self.embeddings_tensors.shape[0] *
                    self.embeddings_tensors.shape[1],
                    self.embeddings_tensors.shape[2],
                )[self.embeddings_slice[0]:self.embeddings_slice[1]]
                assert self.embeddings_weights is not None
                self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed token ids ``x`` and add the per-token LoRA contribution.

        ``x`` itself is left unmodified.
        """
        added_tokens_mask = x > self.base_layer.org_vocab_size - 1
        # NOTE(review): embeddings_indices comes from PunicaWrapper (not
        # visible here); presumably row 1 offsets ids into
        # lora_a_stacked_2d and row 0 offsets added-vocab ids in the base
        # table -- confirm against the wrapper.
        embeddings_indices = self.punica_wrapper.embeddings_indices
        indices = embeddings_indices[1].view_as(x)
        full_lora_a_embeddings = F.embedding(
            x + indices,
            self.lora_a_stacked_2d,
        )
        indices = embeddings_indices[0].view_as(x)
        # Build the shifted ids out of place: an in-place add would
        # silently rewrite the caller's token-id tensor.
        full_output = self.base_layer.forward(
            x + indices * added_tokens_mask)

        full_output_org = full_output
        if full_output.ndim == 3:
            full_output = full_output.view(
                full_output.shape[0] * full_output.shape[1], -1)
        if full_lora_a_embeddings.ndim == 3:
            full_lora_a_embeddings = full_lora_a_embeddings.view(
                full_lora_a_embeddings.shape[0] *
                full_lora_a_embeddings.shape[1],
                -1,
            )

        # Embedding layer only need expand op
        self.punica_wrapper.add_expand(full_output,
                                       full_lora_a_embeddings,
                                       self.lora_b_stacked,
                                       add_input=True)
        return full_output.view_as(full_output_org)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return type(source_layer) is VocabParallelEmbedding

264
265

class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
    """
    LoRA on top of ColumnParallelLinear layer.

    LoRA B is sliced for tensor parallelism.
    """

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.tp_size = get_tensor_model_parallel_world_size()
        # Column-parallel: full input width, this rank's output partition.
        self.input_size = self.base_layer.input_size
        self.output_size = self.base_layer.output_size_per_partition
        self.device = _get_lora_device(self.base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate zeroed LoRA A/B buffers for up to ``max_loras`` adapters.

        LoRA A's rank dimension is divided across the tensor-parallel group
        only when ``lora_config.fully_sharded_loras`` is set; LoRA B always
        spans this rank's output partition.
        """
        self.lora_config = lora_config
        self.tp_size = get_tensor_model_parallel_world_size()
        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))
        self.lora_a_stacked = torch.zeros(
            max_loras,
            1,
            lora_a_output_size_per_partition,
            self.input_size,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            max_loras,
            1,
            self.output_size,
            lora_config.max_lora_rank,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        # Width of this rank's LoRA B output; also the shard size used when
        # slicing a full LoRA B matrix in slice_lora_b().
        self.output_dim = self.lora_b_stacked.shape[2]

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """LoRA A is replicated across ranks, so nothing to slice."""
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Keep the output columns of LoRA B that belong to this rank."""
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        shard_size = self.output_dim
        start_idx = tensor_model_parallel_rank * shard_size
        end_idx = (tensor_model_parallel_rank + 1) * shard_size
        lora_b = lora_b[:, start_idx:end_idx]
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Load adapter ``index`` (sharded first when tp_size > 1)."""
        self.reset_lora(index)

        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        # The buffers are laid out (rank, input) for A and (output, rank)
        # for B, so the incoming matrices are copied in transposed.
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the base layer's quant method, then add the LoRA output."""
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
                                     self.lora_b_stacked, 1.0)
        return output

    def forward(self, input_):
        """Forward of ColumnParallelLinear

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output_parallel = self.apply(input_, bias)
        if self.base_layer.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return type(source_layer) is ColumnParallelLinear or (
            type(source_layer) is MergedColumnParallelLinear
            and len(packed_modules_list) == 1)

388
389
390
391
392
393
394
395
396
397
398
399
400
401

class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
    packed together (eg. gate_proj + up_proj -> gate_up_proj).

    This means we have 2 LoRAs, each applied to one half of the layer.

    Both slices must have the same size.
    """

    def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate one zeroed LoRA A/B buffer pair per packed slice.

        Raises:
            ValueError: if the base layer does not have exactly two output
                slices of equal size.
        """
        self.lora_config = lora_config
        n_slices = 2
        if not (len(self.base_layer.output_sizes) == n_slices
                and self.base_layer.output_sizes[0]
                == self.base_layer.output_sizes[1]):
            raise ValueError(
                "LoRAColumnParallelLinear2Slice requires 2 slices with "
                "the same size.")
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                self.output_size // n_slices,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(n_slices))

        # Per-rank width of one slice; also the shard size for LoRA B.
        self.output_dim = self.lora_b_stacked[0].shape[2]

    def reset_lora(self, index: int):
        self.lora_a_stacked[0][index] = 0
        self.lora_a_stacked[1][index] = 0
        self.lora_b_stacked[0][index] = 0
        self.lora_b_stacked[1][index] = 0

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """LoRA A is replicated across ranks, so nothing to slice."""
        return lora_a

    def slice_lora_b(
        self, lora_b: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """Keep this rank's output columns of each present LoRA B slice.

        ``None`` entries (the adapter does not target that sub-layer) pass
        through unchanged. Every other entry is sharded on its own, so an
        adapter that covers only one of the two packed sub-layers is still
        sliced to this rank's width.
        """
        shard_size = self.output_dim
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        return [
            None if lora_b_i is None else lora_b_i[:, start_idx:end_idx]
            for lora_b_i in lora_b
        ]

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Load adapter ``index``; ``lora_a``/``lora_b`` hold one entry
        (or ``None``) per packed slice."""
        self.reset_lora(index)

        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        # Copied transposed into the (rank, input) / (output, rank) buffers.
        for lora_a_i, lora_b_i, a_stacked, b_stacked in zip(
                lora_a, lora_b, self.lora_a_stacked, self.lora_b_stacked):
            if lora_a_i is None:
                continue
            a_stacked[index, 0, :lora_a_i.shape[1],
                      :lora_a_i.shape[0]].copy_(lora_a_i.T,
                                                non_blocking=True)
            b_stacked[index, 0, :lora_b_i.shape[1],
                      :lora_b_i.shape[0]].copy_(lora_b_i.T,
                                                non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the base layer, then add both slices' LoRA outputs."""
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        self.punica_wrapper.add_lora_packed_nslice(
            output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0,
            (self.output_dim, self.output_dim))
        return output

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return (type(source_layer) is MergedColumnParallelLinear
                and len(packed_modules_list) == 2)
515

516
517

class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """
    ColumnParallelLinear layer that is specifically designed for
    qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
    only contain a single LoRA within their qkv_proj layer.

    During inference with Tensor Parallel, the weights of lora_b
    must be accurately partitioned according to the respective ranks.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        self.tp_size = get_tensor_model_parallel_world_size()
        self.q_proj_total_size = (self.base_layer.total_num_heads *
                                  self.base_layer.head_size)
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
                                   self.base_layer.head_size)

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Extract this rank's q, k and v columns from a fused LoRA B.

        The full matrix's columns are laid out as [q | k | v]; the three
        per-rank pieces are concatenated back into one tensor.
        """
        tp_rank = get_tensor_model_parallel_rank()
        self.q_shard_id = tp_rank
        # Consecutive groups of num_kv_head_replicas ranks share a KV shard.
        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
        lora_b_q = lora_b[:, self.q_proj_shard_size *
                          self.q_shard_id:self.q_proj_shard_size *
                          (self.q_shard_id + 1)]
        k_offset = self.q_proj_total_size
        lora_b_k = lora_b[:, k_offset +
                          self.kv_proj_shard_size * self.kv_shard_id:k_offset +
                          self.kv_proj_shard_size * (self.kv_shard_id + 1)]
        v_offset = k_offset + self.kv_proj_total_size
        lora_b_v = lora_b[:, v_offset +
                          self.kv_proj_shard_size * self.kv_shard_id:v_offset +
                          self.kv_proj_shard_size * (self.kv_shard_id + 1)]
        lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Load adapter ``index`` (sharded first when tp_size > 1)."""
        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is QKVParallelLinear and len(
            packed_modules_list) == 1


class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 3 sublayers (slices)
    packed together in qkv proj fashion
    (q_proj + k_proj + v_proj -> qkv_proj).

    This means we have 3 LoRAs, each applied to one slice of the layer.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate one zeroed LoRA A/B buffer pair each for q, k and v."""
        self.lora_config = lora_config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.q_shard_id = self.tp_rank
        # Consecutive groups of num_kv_head_replicas ranks share a KV shard.
        self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas

        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))

        # Per-rank output width of the q, k and v slices, in that order.
        self.output_slices = (
            self.q_proj_shard_size,
            self.kv_proj_shard_size,
            self.kv_proj_shard_size,
        )
        # q, k, v
        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in self.output_slices)
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                slice_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for slice_size in self.output_slices)

        self.packed_indices: Optional[torch.Tensor] = None
        self.standard_indices: Optional[torch.Tensor] = None
        # lazily initialized.
        self.indices: torch.Tensor
        self.indices_len: List[int]

    def reset_lora(self, index: int):
        for a_stacked, b_stacked in zip(self.lora_a_stacked,
                                        self.lora_b_stacked):
            a_stacked[index] = 0
            b_stacked[index] = 0

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """LoRA A is replicated across ranks, so nothing to slice."""
        return lora_a

    def slice_lora_b(
        self, lora_b: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """Keep this rank's columns of each present q/k/v LoRA B slice.

        ``None`` entries (the adapter does not target that projection) are
        passed through.
        """
        shard_specs = (
            (self.q_proj_shard_size, self.q_shard_id),
            (self.kv_proj_shard_size, self.kv_shard_id),
            (self.kv_proj_shard_size, self.kv_shard_id),
        )
        return [
            None if lora_b_i is None else
            lora_b_i[:, shard_size * shard_id:shard_size * (shard_id + 1)]
            for lora_b_i, (shard_size, shard_id) in zip(lora_b, shard_specs)
        ]

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Load adapter ``index``; ``lora_a``/``lora_b`` hold one entry
        (or ``None``) for each of q, k, v."""
        self.reset_lora(index)

        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        # Copied transposed into the (output, rank) / (rank, input) buffers.
        for lora_b_i, b_stacked in zip(lora_b, self.lora_b_stacked):
            if lora_b_i is not None:
                b_stacked[index, 0, :lora_b_i.shape[1],
                          :lora_b_i.shape[0]].copy_(lora_b_i.T,
                                                    non_blocking=True)
        for lora_a_i, a_stacked in zip(lora_a, self.lora_a_stacked):
            if lora_a_i is not None:
                a_stacked[index, 0, :lora_a_i.shape[1],
                          :lora_a_i.shape[0]].copy_(lora_a_i.T,
                                                    non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the base layer, then add the q/k/v LoRA outputs."""
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        self.punica_wrapper.add_lora_packed_nslice(output, x,
                                                   self.lora_a_stacked,
                                                   self.lora_b_stacked, 1.0,
                                                   self.output_slices)
        return output

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return (type(source_layer) is QKVParallelLinear
                and len(packed_modules_list) == 3)
780

781
782
783
784
785
786

class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper around a :class:`RowParallelLinear` layer.

    The LoRA ``A`` buffer is sized to this rank's input partition
    (``input_size_per_partition``). Incoming ``A`` matrices are sliced to this
    rank's shard when tensor parallelism is active. The LoRA ``B`` buffer
    spans the full output width, unless ``fully_sharded_loras`` is enabled.
    In that case the output width is divided by the tensor-parallel world
    size.
    """

    def __init__(self, base_layer: RowParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer
        # Input width owned by this rank, and the layer's full output width.
        self.input_size = base_layer.input_size_per_partition
        self.output_size = base_layer.output_size
        self.device = _get_lora_device(base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate zero-initialised LoRA A/B buffers for ``max_loras``."""
        self.lora_config = lora_config
        self.tp_rank = get_tensor_model_parallel_rank()

        max_rank = lora_config.max_lora_rank
        a_shape = (max_loras, 1, max_rank, self.input_size)
        self.lora_a_stacked = torch.zeros(a_shape,
                                          dtype=lora_config.lora_dtype,
                                          device=self.device)

        world_size = get_tensor_model_parallel_world_size()
        if lora_config.fully_sharded_loras:
            b_rows = divide(self.output_size, world_size)
        else:
            b_rows = self.output_size
        b_shape = (max_loras, 1, b_rows, max_rank)
        self.lora_b_stacked = torch.zeros(b_shape,
                                          dtype=lora_config.lora_dtype,
                                          device=self.device)

    def reset_lora(self, index: int):
        """Zero the A and B slots that belong to adapter ``index``."""
        for buffer in (self.lora_a_stacked, self.lora_b_stacked):
            buffer[index] = 0

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """Return the rows of ``lora_a`` for this rank's input shard."""
        shard = self.input_size
        begin = get_tensor_model_parallel_rank() * shard
        return lora_a[begin:begin + shard, :]

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """``B`` is not sharded for a row-parallel layer; return it as is."""
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Load adapter weights into slot ``index``.

        The slot is cleared first. Under tensor parallelism the matrices are
        sliced to this rank's shard. Each matrix is then copied transposed
        into the top-left corner of its stacked buffer.
        ``embeddings_tensor`` is unused.
        """
        self.reset_lora(index)

        if self.base_layer.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        for dst, src in ((self.lora_a_stacked, lora_a),
                         (self.lora_b_stacked, lora_b)):
            dst[index, 0, :src.shape[1], :src.shape[0]].copy_(
                src.T, non_blocking=True)

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        """Base projection of ``x`` plus the LoRA delta, added in place."""
        out = self.base_layer.quant_method.apply(self.base_layer, x)
        self.punica_wrapper.add_lora(out, x, self.lora_a_stacked,
                                     self.lora_b_stacked, 1.0)
        return out

    def forward(self, input_):
        """Forward of RowParallelLinear

        Args:
            input_: tensor whose last dimension is `input_size`. If
                    `input_is_parallel` is set, then the last dimension
                    is `input_size // tp_size`.

        Returns:
            - output
            - bias
        """
        base = self.base_layer

        # Take this rank's slice of the input unless the caller already did.
        if base.input_is_parallel:
            input_parallel = input_
        else:
            rank = get_tensor_model_parallel_rank()
            pieces = split_tensor_along_last_dim(
                input_, num_partitions=base.tp_size)
            input_parallel = pieces[rank].contiguous()

        output_parallel = self.apply(input_parallel)

        # Partial sums from every rank are combined only when requested.
        needs_reduce = base.reduce_results and base.tp_size > 1
        output_ = (tensor_model_parallel_all_reduce(output_parallel)
                   if needs_reduce else output_parallel)

        # With skip_bias_add the bias goes back to the caller unapplied.
        if base.skip_bias_add:
            return output_, base.bias
        if base.bias is None:
            return output_, None
        return output_ + base.bias, None

    @property
    def weight(self):
        """Base weight: ``weight`` if present, otherwise ``qweight``."""
        base = self.base_layer
        if hasattr(base, "weight"):
            return base.weight
        return base.qweight

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Match exact ``RowParallelLinear`` instances only (no subclasses)."""
        return type(source_layer) is RowParallelLinear

920

921
class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
    """
    LoRA wrapper for LogitsProcessor, with extra logic to handle the
    application of the LoRA adapter and added LoRA vocabulary.

    Args:
        base_layer: LogitsProcessor layer
        hidden_size: hidden size of the model
        dtype: data type of the model
        device: device of the model
        sharded_to_full_mapping: index mapping from sharded vocab to full vocab
            received from base_layer.get_sharded_to_full_mapping(). If None,
            no reindexing will be done.
    """

    def __init__(self, base_layer: LogitsProcessor, hidden_size: int,
                 dtype: torch.dtype, device: torch.device,
                 sharded_to_full_mapping: Optional[List[int]]) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.sharded_to_full_mapping = sharded_to_full_mapping

    # `forward` runs the base layer's forward with `self` bound to this
    # wrapper. The properties below expose base-layer attributes that it
    # presumably reads (confirm against LogitsProcessor.forward).

    @property
    def logits_as_input(self):
        return self.base_layer.logits_as_input

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def scale(self):
        return self.base_layer.scale

    @property
    def soft_cap(self):
        return self.base_layer.soft_cap

    @property
    def use_gather(self):
        return self.base_layer.use_gather

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate LoRA A/B buffers and the added-vocab embedding table.

        Raises:
            ValueError: if the base vocabulary exceeds the supported maximum.
        """
        # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
        # Only an upper bound is enforced. The former chained comparison
        # `32000 < vocab_size > 128512` reduced to exactly this check.
        max_vocab_size = 128512
        if self.base_layer.vocab_size > max_vocab_size:
            raise ValueError("When using LoRA, vocab size must be "
                             f"<= {max_vocab_size}, got "
                             f"{self.base_layer.vocab_size}")
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                # Pad for kernel compatibility
                math.ceil(self.base_layer.vocab_size /
                          lora_config.lora_vocab_padding_size) *
                lora_config.lora_vocab_padding_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        # Rows for unused added-vocab slots stay at -inf, so their logits
        # come out as -inf (or NaN, which _get_logits maps back to -inf).
        self.embeddings_tensors = torch.full(
            (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
            fill_value=float("-inf"),
            dtype=self.dtype,
            device=self.device,
        )
        if self.sharded_to_full_mapping is not None:
            self.sharded_to_full_mapping_gpu = torch.tensor(
                self.sharded_to_full_mapping,
                device=self.device,
                dtype=torch.long)
        else:
            self.sharded_to_full_mapping_gpu = None

    def reset_lora(self, index: int):
        """Clear slot ``index``: zero A/B and reset added embeddings to -inf."""
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = float("-inf")

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Load adapter weights (and optional added-vocab embeddings)."""
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1], ] = embeddings_tensor

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: VocabParallelEmbedding,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        """Compute logits including LoRA deltas and added LoRA vocabulary.

        Returns None when ``tensor_model_parallel_gather`` yields None
        (i.e. on ranks that do not receive the gathered logits).
        """
        # Get the logits for the next tokens.
        logits = lm_head.linear_method.apply(lm_head, hidden_states)
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        if logits is None:
            return None

        if self.sharded_to_full_mapping_gpu is not None:
            # Reindex full logits tensor to ensure 1:1 mapping between
            # index and token_id
            # Example for:
            #   org_vocab_size = 4
            #   added_vocab_size = 2
            #   pad_to_size = 8
            #   tp_size = 2

            # indices:  [0, 1, 2,  3, 4, 5, 6,  7]
            # token_id: [0, 1, 4, -1, 2, 3, 5, -1]

            # Therefore, the mapping is expected to be:
            # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex,
            # we get:
            # indices:  [0, 1, 2, 3, 4, 5,  6,  7]
            # token_id: [0, 1, 2, 3, 4, 5, -1, -1]
            logits = logits[:, self.sharded_to_full_mapping_gpu]

        # Shape (max_loras + 1, extra_vocab, num_tokens); the extra trailing
        # block is filled with -inf below.
        lora_logits = torch.empty(
            self.embeddings_tensors.shape[0] + 1,
            self.embeddings_tensors.shape[1],
            hidden_states.shape[0],
            dtype=self.embeddings_tensors.dtype,
            device=self.embeddings_tensors.device,
        )
        torch.matmul(self.embeddings_tensors,
                     hidden_states.T,
                     out=lora_logits[:-1])
        lora_logits[-1] = float("-inf")
        # -> (max_loras + 1, num_tokens, extra_vocab)
        lora_logits = lora_logits.mT
        # NOTE(review): indices_padded presumably encodes
        # lora_slot * num_tokens + token_idx so each token picks its own
        # adapter's row; confirm against PunicaWrapper.
        indices_padded = self.punica_wrapper.sampler_indices_padded
        # -inf embedding rows can produce NaN (e.g. -inf * 0); map NaN back
        # to -inf so those slots can never be sampled.
        lora_logits = (lora_logits.reshape(
            lora_logits.shape[0] * lora_logits.shape[1],
            lora_logits.shape[2],
        ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"),
                                                      posinf=float("inf"),
                                                      neginf=float("-inf")))
        # Added LoRA tokens occupy the ids right after the original vocab.
        logits[:,
               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
               lora_logits.shape[1], ] = lora_logits

        # LogitsProcessorWithLoRA always uses bgmv
        self.punica_wrapper.add_lora_logits(logits, hidden_states,
                                            self.lora_a_stacked,
                                            self.lora_b_stacked, 1.0)

        # Remove paddings in vocab (if any).
        logits = logits[:, :self.base_layer.vocab_size]
        return logits

    def forward(self, *args, **kwargs):
        # Reuse the base layer's forward with this wrapper bound as `self`,
        # so it dispatches to the LoRA-aware `_get_logits` above.
        return type(self.base_layer).forward(self, *args, **kwargs)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # Special handling for the LogitsProcessor.
        return False
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152


class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
    """RoPE wrapper supporting several linear scaling factors at once.

    ``create_lora_weights`` replaces the wrapped rotary embedding with a
    :class:`LinearScalingRotaryEmbedding` built from the sorted, de-duplicated
    union of the base layer's own factor (1.0 unless it is already a
    ``LinearScalingRotaryEmbedding``) and the configured
    ``long_lora_scaling_factors``. ``forward`` passes the Punica wrapper's
    ``long_lora_indices`` as ``offsets``, so one kernel call can serve
    multiple LoRA adapters.
    """

    def __init__(self, base_layer: RotaryEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer

    @property
    def scaling_factors(self):
        return self.base_layer.scaling_factors

    @property
    def rotary_dim(self):
        return self.base_layer.rotary_dim

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Rebuild the base layer with every required scaling factor."""
        configured = lora_config.long_lora_scaling_factors
        extra_factors = list(configured) if configured else []

        previous = self.base_layer
        if isinstance(previous, LinearScalingRotaryEmbedding):
            own_factor = previous.scaling_factor
        else:
            own_factor = 1.0
        merged_factors = sorted({own_factor, *extra_factors})

        self.base_layer = LinearScalingRotaryEmbedding(
            previous.head_size,
            previous.rotary_dim,
            previous.max_position_embeddings,
            previous.base,
            previous.is_neox_style,
            merged_factors,
            previous.dtype,
        )

    def reset_lora(self, index: int):
        """No per-adapter state to clear."""

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """No per-adapter weights to load."""

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        offsets = self.punica_wrapper.long_lora_indices
        return self.base_layer(positions, query, key, offsets=offsets)

    @property
    def scaling_factor_to_offset(self) -> Dict[float, int]:
        return self.base_layer.scaling_factor_to_offset

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        layer_type = type(source_layer)
        return (layer_type is LinearScalingRotaryEmbedding
                or layer_type is RotaryEmbedding)

    def extra_repr(self) -> str:
        return self.base_layer.extra_repr()