# pylint: disable=unused-argument
import inspect
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Type

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

from vllm.config import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce,
                              tensor_model_parallel_gather)
from vllm.lora.punica import add_lora, add_lora_slice, bgmv
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)

if TYPE_CHECKING:
    pass


def _get_lora_device(base_layer: nn.Module) -> torch.device:
    # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
    """Returns the device for where to place the LoRA tensors."""
Jee Li's avatar
Jee Li committed
35
    # unquantizedLinear
36
37
    if hasattr(base_layer, "weight"):
        return base_layer.weight.device
Jee Li's avatar
Jee Li committed
38
39
40
41
42
43
44
45
    # GPTQ/AWQ/SqueezeLLM
    elif hasattr(base_layer, "qweight"):
        return base_layer.qweight.device
    # marlin
    elif hasattr(base_layer, "B"):
        return base_layer.B.device
    else:
        raise ValueError(f"Unsupported base layer: {base_layer}")
46
47


48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
def _apply_lora(
    x: torch.Tensor,
    lora_a_stacked: torch.Tensor,
    lora_b_stacked: torch.Tensor,
    indices: torch.Tensor,
    output: torch.Tensor,
):
    """Accumulate each row's LoRA delta into ``output``.

    Every input row is paired with an entry of ``indices`` that selects
    which stacked LoRA (A, B) pair applies to it; an index of -1 means no
    LoRA is applied to that row. The leading dimensions of ``x``,
    ``output`` and ``indices`` are flattened before the ``add_lora`` call,
    and the result is returned viewed in ``output``'s original shape.

    Args:
        x: Inputs whose last dimension is the hidden size.
        lora_a_stacked: Stacked A matrices, indexed by LoRA slot (the layers
            in this file allocate them as
            ``(max_loras, 1, max_lora_rank, input_size)``).
        lora_b_stacked: Stacked B matrices, indexed by LoRA slot (allocated
            as ``(max_loras, 1, output_size, max_lora_rank)``).
        indices: One LoRA slot per input row.
        output: Base-layer output that the LoRA results are added to.
    """
    rows_in = x.view(-1, x.shape[-1])
    rows_out = output.view(-1, output.shape[-1])
    add_lora(rows_out, rows_in, lora_a_stacked, lora_b_stacked,
             indices.view(-1), 0, 1.0)
    return rows_out.view_as(output)


def _apply_lora_packed_nslice(
    x: torch.Tensor,
    lora_a_stacked: Tuple[torch.Tensor, ...],
    lora_b_stacked: Tuple[torch.Tensor, ...],
    indices: torch.Tensor,
    output: torch.Tensor,
    output_slices: Tuple[int, ...],
):
    """Applies lora to each input.

    This method applies all loras to each input. It uses the
    indices vector to determine which lora yields the
    correct output. An index of -1 means no lora should be
    applied. This method adds the final lora results to the
    output.

    This method is used for layers that are composed of multiple sublayers
    (slices) packed together. Slice ``i`` is written into the output
    columns ``[sum(output_slices[:i]), sum(output_slices[:i + 1]))``.

    Input shapes:
        x:                 (batch_size, hidden_dim)
        lora_a_stacked:    n element tuple of stacked A matrices, one per
                           slice, indexed by LoRA slot
        lora_b_stacked:    n element tuple of stacked B matrices, one per
                           slice, indexed by LoRA slot
        indices:           (batch_size)
        output:            (batch_size, sum(output_slices))
        output_slices:     n element tuple of (slice_size...),
                           where n is number of slices
    """
    org_output = output
    x = x.view(-1, x.shape[-1])
    output = output.view(-1, output.shape[-1])
    indices = indices.view(-1)
    # Column offset in ``output`` where the current slice starts.
    offset_left = 0
    for slice_idx, slice_size in enumerate(output_slices):
        add_lora_slice(output, x, lora_a_stacked[slice_idx],
                       lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left,
                       slice_size)
        offset_left += slice_size
    return output.view_as(org_output)


@dataclass
class LoRAMapping:
    """Per-token and per-sampled-token LoRA assignment.

    Both fields are normalized to tuples after construction, so callers
    may pass any iterable (e.g. a list).
    """

    # Per every token in input_ids:
    index_mapping: Tuple[int, ...]
    # Per sampled token:
    prompt_mapping: Tuple[int, ...]

    def __post_init__(self):
        for field_name in ("index_mapping", "prompt_mapping"):
            setattr(self, field_name, tuple(getattr(self, field_name)))


class BaseLayerWithLoRA(nn.Module):
    """Base class for layers that wrap a model layer with LoRA support.

    The lifecycle hooks defined here (``create_lora_weights``,
    ``reset_lora``, ``set_lora``, ``set_mapping``) are no-ops; concrete
    LoRA layers override them. ``can_replace_layer`` must be implemented
    by every subclass.
    """

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        """Initializes lora matrices."""
        ...

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Overwrites lora tensors at index."""
        ...

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        """Sets the mapping indices."""
        ...

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        raise NotImplementedError


class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper for ``VocabParallelEmbedding``.

    Besides the low-rank A/B matrices, this layer supports LoRA "extra
    vocabulary" embeddings. The rows of this rank's base embedding table
    at or beyond ``org_vocab_size`` are zeroed in ``create_lora_weights``
    and overwritten from the adapter's ``embeddings_tensor`` in
    ``set_lora``.
    """

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:

        lora_vocab_start_idx = self.base_layer.org_vocab_size
        weights_idx = None
        if self.base_layer.vocab_end_index > lora_vocab_start_idx:
            # This rank's vocab shard overlaps the LoRA extra-vocab region:
            # remember which part of the flattened extra embeddings maps
            # onto it, and zero the corresponding base-weight rows.
            # We can start adding lora weights
            weights_idx = max(
                lora_vocab_start_idx - self.base_layer.vocab_start_index, 0)
            self.embeddings_slice = (self.base_layer.vocab_start_index -
                                     self.base_layer.org_vocab_size +
                                     weights_idx,
                                     self.base_layer.vocab_end_index -
                                     self.base_layer.org_vocab_size)
            self.embeddings_weights = self.base_layer.weight.data[weights_idx:]
            self.embeddings_weights.fill_(0)
        else:
            self.embeddings_slice = None
            self.embeddings_weights = None

        self.embeddings_tensors = torch.zeros(
            (
                max_loras,
                lora_config.lora_extra_vocab_size,
                self.base_layer.embedding_dim,
            ),
            dtype=self.base_layer.weight.dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                self.base_layer.org_vocab_size +
                lora_config.lora_extra_vocab_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.embedding_dim,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        # 2-D view (max_loras * vocab_rows, rank) sharing storage with
        # lora_a_stacked, so F.embedding can gather rows across all slots.
        self.lora_a_stacked_2d = self.lora_a_stacked.view(
            self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
            self.lora_a_stacked.shape[2],
        )
        self.indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None
        self.embeddings_indices = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Loads one adapter into slot ``index``.

        ``lora_a`` is copied as-is; ``lora_b`` is stored transposed. If an
        ``embeddings_tensor`` is given it is copied into the extra-vocab
        buffer and, when this rank owns part of the extra-vocab region,
        mirrored into the base embedding weights.
        """
        self.reset_lora(index)
        self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
            lora_a, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1]].copy_(embeddings_tensor, non_blocking=True)
            if self.embeddings_slice is not None:
                # TODO(yard1): Optimize this copy, we don't need to copy
                # everything, just the modified part
                embeddings = self.embeddings_tensors.view(
                    self.embeddings_tensors.shape[0] *
                    self.embeddings_tensors.shape[1],
                    self.embeddings_tensors.shape[2]
                )[self.embeddings_slice[0]:self.embeddings_slice[1]]
                self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.embeddings_indices = embeddings_indices
        self.indices_len = indices_len

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embeds token ids ``x`` and adds the per-token LoRA contribution."""
        # Ids past the original vocabulary are LoRA-added tokens.
        added_tokens_mask = x > self.base_layer.org_vocab_size - 1
        embedding_len = self.indices_len[3]
        # NOTE(review): embeddings_indices[1] / [0] are per-token offsets
        # produced by the LoRA mapping code elsewhere; presumably [1] picks
        # the LoRA slot's rows in lora_a_stacked_2d and [0] shifts added
        # token ids into the extra-vocab rows -- confirm against the caller.
        indices = self.embeddings_indices[1][:embedding_len].view_as(x)
        full_lora_a_embeddings = F.embedding(
            x + indices,
            self.lora_a_stacked_2d,
        )
        indices = self.embeddings_indices[0][:embedding_len].view_as(x)
        # NOTE(review): ``x.add_`` shifts the ids of added tokens in place,
        # so the caller's ``x`` tensor is modified by this call.
        full_output = self.base_layer.forward(
            x.add_(indices * added_tokens_mask))

        full_output_org = full_output
        if full_output.ndim == 3:
            full_output = full_output.view(
                full_output.shape[0] * full_output.shape[1], -1)
        if full_lora_a_embeddings.ndim == 3:
            full_lora_a_embeddings = full_lora_a_embeddings.view(
                full_lora_a_embeddings.shape[0] *
                full_lora_a_embeddings.shape[1], -1)
        bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        return full_output.view_as(full_output_org)

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is VocabParallelEmbedding


class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper for ``ColumnParallelLinear``.

    The output dimension is sharded across tensor-parallel ranks, so each
    rank keeps only its own column shard of ``lora_b``; ``lora_a`` spans
    the full input size.
    """

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size = self.base_layer.input_size
        self.output_size = self.base_layer.output_size_per_partition
        self.device = _get_lora_device(self.base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        """Allocates zero-initialized stacked LoRA buffers.

        A is ``(max_loras, 1, max_lora_rank, input_size)`` and B is
        ``(max_loras, 1, output_size_per_partition, max_lora_rank)``.
        """
        self.lora_a_stacked = torch.zeros(
            max_loras,
            1,
            lora_config.max_lora_rank,
            self.input_size,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            max_loras,
            1,
            self.output_size,
            lora_config.max_lora_rank,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )

        self.indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None
        self.output_dim = self.lora_b_stacked.shape[2]

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Loads one adapter into slot ``index`` (both matrices transposed)."""
        self.reset_lora(index)

        if self.tp_size > 1:
            # Keep only this rank's column shard of lora_b.
            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
            shard_size = self.output_dim
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            lora_b = lora_b[:, start_idx:end_idx]

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.indices_len = indices_len

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Runs the base layer's matmul, then adds the LoRA result."""
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer, x, bias)
        _apply_lora(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
        )
        return output

    def forward(self, input_):
        """Forward of ColumnParallelLinear

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output_parallel = self.apply_weights(input_, bias)
        if self.base_layer.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is ColumnParallelLinear or (
            type(source_layer) is MergedColumnParallelLinear
            and len(packed_modules_list) == 1)


class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
    packed together (eg. gate_proj + up_proj -> gate_up_proj).

    This means we have 2 LoRAs, each applied to one half of the layer.

    Both slices must have the same size.
    """

    def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        """Allocates one zero-initialized A/B buffer pair per slice.

        Raises:
            ValueError: If the base layer does not have exactly 2 output
                slices of equal size.
        """
        n_slices = 2
        if not (len(self.base_layer.output_sizes) == n_slices
                and self.base_layer.output_sizes[0]
                == self.base_layer.output_sizes[1]):
            raise ValueError(
                "LoRAColumnParallelLinear2Slice requires 2 slices with "
                "the same size.")
        self.tp_size = get_tensor_model_parallel_world_size()

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(n_slices))
        # Each slice owns half of this rank's output columns.
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                self.output_size // 2,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(n_slices))

        self.indices: Optional[torch.Tensor] = None
        # Initialized like the parent class; filled in by set_mapping.
        self.indices_len: Optional[List[int]] = None
        self.output_dim = self.lora_b_stacked[0].shape[2]

    def reset_lora(self, index: int):
        self.lora_a_stacked[0][index] = 0
        self.lora_a_stacked[1][index] = 0
        self.lora_b_stacked[0][index] = 0
        self.lora_b_stacked[1][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Loads a 2-slice adapter into slot ``index``.

        ``lora_a`` and ``lora_b`` are indexable pairs, one entry per slice;
        a slice whose ``lora_a`` entry is None is left zeroed.
        """
        self.reset_lora(index)

        if self.tp_size > 1:
            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
            shard_size = self.output_dim
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            # Entries may be None (the copy loop below skips such slices),
            # so shard only the present ones; slicing None would raise
            # TypeError.
            lora_b = tuple(
                None if lora_b[i] is None else lora_b[i][:, start_idx:end_idx]
                for i in range(2))

        for i in range(2):
            if lora_a[i] is None:
                continue
            self.lora_a_stacked[i][
                index, 0, :lora_a[i].shape[1], :lora_a[i].shape[0]].copy_(
                    lora_a[i].T, non_blocking=True)
            self.lora_b_stacked[i][
                index, 0, :lora_b[i].shape[1], :lora_b[i].shape[0]].copy_(
                    lora_b[i].T, non_blocking=True)

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Runs the base matmul, then adds each slice's LoRA result."""
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer, x, bias)
        _apply_lora_packed_nslice(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
            (self.output_dim, self.output_dim),
        )
        return output

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is MergedColumnParallelLinear and len(
            packed_modules_list) == 2


class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """
    ColumnParallelLinear layer that is specifically designed for
    qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
    only contain a single LoRA within their qkv_proj layer.

    During inference with Tensor Parallel, the weights of lora_b
    must be accurately partitioned according to the respective ranks.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        self.tp_size = get_tensor_model_parallel_world_size()
        self.q_proj_total_size = (self.base_layer.total_num_heads *
                                  self.base_layer.head_size)
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
                                   self.base_layer.head_size)

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Loads a single fused-qkv adapter into slot ``index``.

        With tensor parallelism, this rank's Q, K and V column shards are
        cut out of the full ``lora_b`` (laid out as [Q | K | V]) and
        concatenated before being stored.
        """
        self.reset_lora(index)
        if self.tp_size > 1:
            tp_rank = get_tensor_model_parallel_rank()
            self.q_shard_id = tp_rank
            # KV heads may be replicated across ranks.
            self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
            lora_b_q = lora_b[:, self.q_proj_shard_size *
                              self.q_shard_id:self.q_proj_shard_size *
                              (self.q_shard_id + 1)]
            k_offset = self.q_proj_total_size
            lora_b_k = lora_b[:, k_offset + self.kv_proj_shard_size *
                              self.kv_shard_id:k_offset +
                              self.kv_proj_shard_size * (self.kv_shard_id + 1)]
            v_offset = k_offset + self.kv_proj_total_size
            lora_b_v = lora_b[:, v_offset + self.kv_proj_shard_size *
                              self.kv_shard_id:v_offset +
                              self.kv_proj_shard_size * (self.kv_shard_id + 1)]
            lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is QKVParallelLinear and len(
            packed_modules_list) == 1


class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 3 sublayers (slices)
    packed together in qkv proj fashion
    (q_proj + k_proj + v_proj -> qkv_proj).

    This means we have 3 LoRAs, each applied to one slice of the layer.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        """Allocates zero-initialized A/B buffers for the q, k, v slices.

        All A buffers span the full input size; each B buffer is sized to
        this rank's shard of its slice.
        """
        self.tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.q_shard_id = tp_rank
        # KV heads may be replicated across ranks.
        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas

        self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size,
                              self.kv_proj_shard_size)

        # q, k, v
        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(3))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                slice_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for slice_size in self.output_slices)

        # ``indices`` is read by apply_weights; set_mapping fills it in.
        self.indices: Optional[torch.Tensor] = None
        self.packed_indices: Optional[torch.Tensor] = None
        self.standard_indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None

    def reset_lora(self, index: int):
        for slice_a, slice_b in zip(self.lora_a_stacked, self.lora_b_stacked):
            slice_a[index] = 0
            slice_b[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Loads a q/k/v adapter into slot ``index``.

        ``lora_a`` and ``lora_b`` hold one entry per slice (q, k, v); None
        entries are skipped and stay zeroed. With tensor parallelism, only
        this rank's column shard of each ``lora_b`` entry is stored.
        """
        self.reset_lora(index)

        if self.tp_size > 1:
            shards = ((self.q_proj_shard_size, self.q_shard_id),
                      (self.kv_proj_shard_size, self.kv_shard_id),
                      (self.kv_proj_shard_size, self.kv_shard_id))
            lora_b = [
                None if slice_b is None else
                slice_b[:, size * shard_id:size * (shard_id + 1)]
                for slice_b, (size, shard_id) in zip(lora_b, shards)
            ]

        for slice_idx in range(3):
            slice_b = lora_b[slice_idx]
            if slice_b is not None:
                self.lora_b_stacked[slice_idx][
                    index, 0, :slice_b.shape[1], :slice_b.shape[0]].copy_(
                        slice_b.T, non_blocking=True)

        for slice_idx in range(3):
            slice_a = lora_a[slice_idx]
            if slice_a is not None:
                self.lora_a_stacked[slice_idx][
                    index, 0, :slice_a.shape[1], :slice_a.shape[0]].copy_(
                        slice_a.T, non_blocking=True)

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Runs the base matmul, then adds each q/k/v slice's LoRA result."""
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer, x, bias)
        _apply_lora_packed_nslice(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
            self.output_slices,
        )
        return output

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is QKVParallelLinear and len(
            packed_modules_list) == 3


class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper around a :class:`RowParallelLinear` layer.

    The input dimension of a row-parallel layer is split across tensor
    parallel ranks. Each rank therefore keeps only its slice of LoRA A
    (sliced along the input dimension in :meth:`set_lora`) and the whole of
    LoRA B. The LoRA delta is added to the rank-local partial output inside
    :meth:`apply_weights`, before the all-reduce performed in
    :meth:`forward`.
    """

    def __init__(self, base_layer: RowParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer
        # Per-rank slice of the input dimension.
        self.input_size = self.base_layer.input_size_per_partition
        # Output is all-reduced in ``forward``, so every rank produces the
        # base layer's full ``output_size``.
        self.output_size = self.base_layer.output_size
        # Place LoRA buffers on the same device as the (possibly quantized)
        # base weights.
        self.device = _get_lora_device(self.base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        """Allocate zero-initialised LoRA A/B stacks for ``max_loras`` slots.

        Args:
            max_loras: Number of adapter slots to reserve.
            lora_config: Supplies ``max_lora_rank`` and ``lora_dtype``.
            model_config: Unused here; accepted for interface compatibility.
        """
        # Shape: (max_loras, 1, max_lora_rank, input_size_per_partition)
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.input_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        # Shape: (max_loras, 1, output_size, max_lora_rank)
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.output_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None

    def reset_lora(self, index: int):
        """Zero out the LoRA A and B weights stored in slot ``index``."""
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Load one adapter's weights into slot ``index``.

        Args:
            index: Destination slot in the stacked buffers.
            lora_a: LoRA A laid out as ``(full_input_size, rank)``. With
                tensor parallelism, only this rank's rows are kept.
            lora_b: LoRA B laid out as ``(rank, output_size)``.
            embeddings_tensor: Unused by this layer.
        """
        self.reset_lora(index)
        if self.base_layer.tp_size > 1:
            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
            shard_size = self.input_size
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            lora_a = lora_a[start_idx:end_idx, :]

        # The stacks store transposed weights; copy into the top-left corner
        # so adapters with rank < max_lora_rank leave the rest zeroed.
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        """Record the slot mapping used by :meth:`apply_weights`.

        Only ``base_indices`` and ``indices_len`` are used by this layer; the
        other arguments are accepted for interface compatibility.
        """
        self.indices = base_indices
        self.indices_len = indices_len

    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the base row-parallel matmul, then add the LoRA delta.

        Args:
            x: This rank's slice of the input activations.

        Returns:
            The rank-local partial output including the LoRA contribution
            (not yet all-reduced).
        """
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer, x)
        _apply_lora(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
        )
        return output

    def forward(self, input_):
        """Forward of RowParallelLinear

        Args:
            input_: tensor whose last dimension is `input_size`. If
                    `input_is_parallel` is set, then the last dimension
                    is `input_size // tp_size`.

        Returns:
            - output
            - bias
        """
        # Set up backprop all-reduce.
        if self.base_layer.input_is_parallel:
            input_parallel = input_
        else:
            # TODO: simplify code below
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.base_layer.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        output_parallel = self.apply_weights(input_parallel)
        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.base_layer.skip_bias_add:
            output = (output_ + self.base_layer.bias
                      if self.base_layer.bias is not None else output_)
            output_bias = None
        else:
            output = output_
            output_bias = self.base_layer.bias
        return output, output_bias

    @property
    def weight(self):
        """The base layer's ``weight``, or ``qweight`` when it has none."""
        return self.base_layer.weight if hasattr(
            self.base_layer, "weight") else self.base_layer.qweight

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Return True only for an exact ``RowParallelLinear`` instance."""
        return type(source_layer) is RowParallelLinear


class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper around a :class:`LogitsProcessor`.

    Besides the LoRA A/B stacks, each adapter slot may carry extra-vocabulary
    embeddings (``embeddings_tensors``). Their logits are written into the
    columns starting at ``org_vocab_size``. This class is never selected by
    ``can_replace_layer``; it is built explicitly by
    ``from_layer_logits_processor``.
    """

    def __init__(
        self,
        base_layer: LogitsProcessor,
        hidden_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> None:
        """
        Args:
            base_layer: The logits processor being wrapped.
            hidden_size: Width of the hidden states fed to the LM head.
            dtype: dtype used for the extra-vocabulary embedding buffer.
            device: Device on which all LoRA buffers are allocated.
        """
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device

    # The properties below mirror the base layer's attributes, so that
    # ``forward`` (which runs the base class's forward with ``self``) sees
    # the same configuration as the wrapped processor.
    @property
    def logits_as_input(self):
        return self.base_layer.logits_as_input

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def scale(self):
        return self.base_layer.scale

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate LoRA A/B stacks and extra-vocab embeddings.

        Args:
            max_loras: Number of adapter slots to reserve.
            lora_config: Supplies ``max_lora_rank``, ``lora_dtype``,
                ``lora_vocab_padding_size`` and ``lora_extra_vocab_size``.
            model_config: Unused here; accepted for interface compatibility.

        Raises:
            ValueError: If the base vocab size exceeds 128512.
        """
        # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
        # Only an upper bound is enforced. The former chained comparison
        # ``32000 < vocab_size > 128512`` evaluated to exactly this check,
        # so vocab sizes below 32000 are, and remain, accepted.
        vocab_size = self.base_layer.vocab_size
        if vocab_size > 128512:
            raise ValueError("When using LoRA, vocab size must be "
                             f"<= 128512, but got {vocab_size}")
        # Shape: (max_loras, 1, max_lora_rank, hidden_size)
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        # Shape: (max_loras, 1, padded_vocab, max_lora_rank), where
        # padded_vocab rounds vocab_size up to a multiple of
        # lora_vocab_padding_size.
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                # Pad for kernel compatibility
                math.ceil(vocab_size / lora_config.lora_vocab_padding_size) *
                lora_config.lora_vocab_padding_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        # Unused extra-vocab rows stay at -inf.
        self.embeddings_tensors = torch.full(
            (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
            fill_value=float("-inf"),
            dtype=self.dtype,
            device=self.device,
        )
        self.indices: Optional[torch.Tensor] = None
        self.indices_padded: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None

    def reset_lora(self, index: int):
        """Clear slot ``index``: zero A/B and reset embeddings to -inf."""
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = float("-inf")

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Load one adapter (and optional extra-vocab embeddings) into slot
        ``index``.

        A and B are stored transposed, in the top-left corner of their slot.
        ``embeddings_tensor``, if given, is copied into the leading rows and
        columns of ``embeddings_tensors[index]``.
        """
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1], ] = embeddings_tensor

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        """Record the sampler-side slot mapping used by ``_get_logits``.

        Only ``sampler_indices``, ``sampler_indices_padded`` and
        ``indices_len`` are used; the rest is accepted for interface
        compatibility.
        """
        self.indices = sampler_indices
        self.indices_padded = sampler_indices_padded
        self.indices_len = indices_len

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        embedding: torch.Tensor,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        """Compute base logits, fill extra-vocab logits and add the LoRA delta.

        Returns ``None`` when ``tensor_model_parallel_gather`` returns
        ``None``. NOTE(review): presumably this happens on non-destination
        ranks; confirm against the gather implementation.
        """
        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        if logits is None:
            return None

        # Shape: (max_loras + 1, extra_vocab, num_tokens). The extra final
        # slot is filled with -inf below.
        lora_logits = torch.empty(
            self.embeddings_tensors.shape[0] + 1,
            self.embeddings_tensors.shape[1],
            hidden_states.shape[0],
            dtype=self.embeddings_tensors.dtype,
            device=self.embeddings_tensors.device,
        )
        torch.matmul(self.embeddings_tensors,
                     hidden_states.T,
                     out=lora_logits[:-1])
        lora_logits[-1] = float("-inf")
        # -> (max_loras + 1, num_tokens, extra_vocab), flattened to rows of
        # extra_vocab, from which each token's row is picked via
        # ``indices_padded``. Products with the -inf padding rows can yield
        # NaN, which is mapped back to -inf.
        lora_logits = lora_logits.mT
        lora_logits = (lora_logits.reshape(
            lora_logits.shape[0] * lora_logits.shape[1],
            lora_logits.shape[2],
        ).index_select(0,
                       self.indices_padded[:self.indices_len[2]]).nan_to_num_(
                           nan=float("-inf"),
                           posinf=float("inf"),
                           neginf=float("-inf")))
        logits[:,
               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
               lora_logits.shape[1]] = lora_logits

        _apply_lora(
            hidden_states,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[1]],
            logits,
        )

        # Remove paddings in vocab (if any).
        logits = logits[:, :self.base_layer.vocab_size]

        return logits

    def forward(self, *args, **kwargs):
        # Run the base class's forward with ``self`` bound, so that this
        # class's ``_get_logits`` override and mirrored properties are used.
        return type(self.base_layer).forward(self, *args, **kwargs)

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # Special handling for the LogitsProcessor.
        return False


# Registry of every LoRA layer wrapper defined in this module, excluding the
# ``BaseLayerWithLoRA`` base class itself. ``from_layer`` queries each entry's
# ``can_replace_layer``. Because it snapshots ``globals()`` at import time,
# this must stay below the last wrapper class definition.
_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
    candidate
    for candidate in globals().values()
    if inspect.isclass(candidate) and candidate is not BaseLayerWithLoRA
    and issubclass(candidate, BaseLayerWithLoRA)
}


def from_layer(layer: nn.Module,
               max_loras: int,
               lora_config: LoRAConfig,
               packed_modules_list: List,
               model_config: Optional[PretrainedConfig] = None) -> nn.Module:
    """Wrap ``layer`` in the first LoRA class that claims it.

    Each class in ``_all_lora_classes`` is asked via ``can_replace_layer``.
    The first match wraps ``layer`` and allocates LoRA buffers for
    ``max_loras`` adapters.

    Returns:
        The LoRA wrapper, or ``layer`` itself if no class matches.
    """
    # ``_all_lora_classes`` is a set, so iteration order is unspecified.
    # NOTE(review): this relies on ``can_replace_layer`` predicates being
    # mutually exclusive.
    for lora_cls in _all_lora_classes:
        if lora_cls.can_replace_layer(layer, lora_config, packed_modules_list,
                                      model_config):
            ret = lora_cls(layer)
            ret.create_lora_weights(max_loras, lora_config, model_config)
            return ret
    return layer


def from_layer_logits_processor(
    layer: LogitsProcessor,
    lm_head: ParallelLMHead,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
    """Wrap a logits processor with LoRA support.

    The hidden size, dtype and device of the LoRA buffers are taken from
    ``lm_head`` (its ``embedding_dim`` and its weight's dtype and device).

    Args:
        layer: The logits processor to wrap.
        lm_head: The LM head whose embedding shares the logits computation.
        max_loras: Number of adapter slots to reserve.
        lora_config: LoRA configuration used to size the buffers.
        model_config: Optional model config forwarded to buffer creation.

    Returns:
        A ``LogitsProcessorWithLoRA`` with its LoRA weights allocated.
    """
    ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
                                  lm_head.weight.dtype, lm_head.weight.device)
    ret.create_lora_weights(max_loras, lora_config, model_config)
    return ret