# SPDX-License-Identifier: Apache-2.0
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). 
Punica: Multi-Tenant LoRA Serving. 
https://arxiv.org/abs/2310.18547
"""

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import torch

from .utils import compute_meta, convert_mapping

if TYPE_CHECKING:
    # avoid circular import
    from vllm.lora.layers import LoRAMapping
    from vllm.lora.models import LongContextLoRAContext


class PunicaWrapperABC(ABC):
    """
    Abstract interface for punica wrappers.

    Declares the metadata-update hook and the LoRA compute entry points
    (shrink, expand, embedding, linear, logits). Every method here is
    abstract; a concrete wrapper must override all of them.
    """

    @abstractmethod
    def update_metadata(
        self,
        mapping: "LoRAMapping",
        lora_index_to_id: List[Optional[int]],
        max_loras: int,
        vocab_size: int,
        extra_vocab_size: int,
        long_lora_context: Optional["LongContextLoRAContext"] = None,
        **kwargs,
    ) -> None:
        """
        Update the lora-related metadata.

        Args:
            mapping (LoRAMapping): LoRA mapping for the current batch.
            lora_index_to_id (List[Optional[int]]): Entry per LoRA index;
                entries may be None.
            max_loras (int): Maximum number of LoRAs.
            vocab_size (int): Vocabulary size.
            extra_vocab_size (int): Extra vocabulary size.
            long_lora_context (Optional[LongContextLoRAContext]): Long
                context LoRA information, defaults to None.
        """
        raise NotImplementedError

    @abstractmethod
    def add_shrink(
        self,
        y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
        x: torch.Tensor,
        lora_a_stacked: Tuple[torch.Tensor, ...],
        scale: float,
        **kwargs,
    ) -> Optional[torch.Tensor]:
        """
        Performs GEMM for multiple slices of lora_a.

        Args:
            y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors.
            x (torch.Tensor): Input tensor.
            lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights.
            scale (float): Scaling factor for the operation.
        """
        raise NotImplementedError

    @abstractmethod
    def add_expand(
        self,
        y: torch.Tensor,
        x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
        lora_b_stacked: Tuple[torch.Tensor, ...],
        lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
        output_slices: Tuple[int, ...],
        offset_start: int = 0,
        add_inputs: bool = True,
        **kwargs,
    ) -> Optional[torch.Tensor]:
        """
        Performs GEMM and bias addition for multiple slices of lora_b.

        Args:
            y (torch.Tensor): Output tensor.
            x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors.
            lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weights.
            lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
                bias's weights.
            output_slices (Tuple[int, ...]): Every slice's size.
            offset_start (int): The starting position of y, defaults to 0.
            add_inputs (bool): Defaults to True.
        """
        raise NotImplementedError

    @abstractmethod
    def add_lora_embedding(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_b_stacked: torch.Tensor,
        add_inputs: bool = True,
        **kwargs,
    ) -> Optional[torch.Tensor]:
        """
        Applies lora specifically for VocabParallelEmbeddingWithLoRA;
        this layer only requires the expand operation.

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            add_inputs (bool): Defaults to True.
        """
        raise NotImplementedError

    @abstractmethod
    def add_lora_linear(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: Tuple[torch.Tensor, ...],
                        lora_b_stacked: Tuple[torch.Tensor, ...],
                        lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
                        scale: float,
                        output_slices: Tuple[int, ...],
                        *,
                        buffer: Optional[Tuple[torch.Tensor, ...]] = None,
                        **kwargs) -> Optional[torch.Tensor]:
        """
        Applicable to linear-related lora.

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights.
            lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weights.
            lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
                lora's bias.
            scale (float): Scaling factor.
            output_slices (Tuple[int, ...]): Every slice's size.
            buffer (Optional[Tuple[torch.Tensor, ...]]): Keyword-only,
                defaults to None.
        """
        raise NotImplementedError

    @abstractmethod
    def add_lora_logits(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: torch.Tensor,
                        lora_b_stacked: torch.Tensor,
                        scale: float,
                        *,
                        buffer: Optional[torch.Tensor] = None,
                        **kwargs) -> Optional[torch.Tensor]:
        """
        Applies lora specifically for LogitsProcessorWithLoRA.

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_a_stacked (torch.Tensor): lora_a's weights.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            scale (float): Scaling factor.
            buffer (Optional[torch.Tensor]): Keyword-only, defaults to None.
        """
        raise NotImplementedError


class PunicaWrapperBase(PunicaWrapperABC):
    """
    PunicaWrapperBase is designed to manage and provide metadata for the punica
    kernel. The main function is to maintain the state information for
    Multi-LoRA, and to provide the interface for the punica.

    All index buffers are allocated once in ``__init__`` with ``torch.empty``
    and refreshed in place by ``update_metadata``; the public properties
    return prefixes of those buffers limited to the currently valid length.
    """

    def __init__(self, max_num_batched_tokens: int, max_batches: int,
                 device: Union[torch.device, str], **kwargs):
        """
        Allocate the index/metadata buffers on ``device``.

        Args:
            max_num_batched_tokens (int): Length of each per-token buffer.
            max_batches (int): Length of each per-batch buffer.
            device (Union[torch.device, str]): Device for all buffers.
        """
        # Per-token buffers. torch.empty leaves them uninitialized; only the
        # prefix recorded in ``indices_len`` is meaningful after an update.
        self._token_lora_indices = torch.empty(max_num_batched_tokens,
                                               dtype=torch.long,
                                               device=device)
        self._sampler_indices = torch.empty(max_num_batched_tokens,
                                            dtype=torch.long,
                                            device=device)
        self._sampler_indices_padded = torch.empty(max_num_batched_tokens,
                                                   dtype=torch.long,
                                                   device=device)
        self._embeddings_indices = torch.empty(2,
                                               max_num_batched_tokens,
                                               dtype=torch.long,
                                               device=device)
        self._long_lora_indices = torch.empty(max_num_batched_tokens,
                                              dtype=torch.long,
                                              device=device)

        # 5 is the number of indices tensors.
        # base_indices, sampler_indices, sampler_indices_padded,
        # embeddings_indices,long_lora_indices
        self.indices_len: List[Optional[int]] = [None] * 5
        # these attributes are the information required for sgmv kernel
        self._seq_start_locs = torch.empty(max_batches,
                                           dtype=torch.long,
                                           device=device)
        self._seq_lengths = torch.empty(max_batches,
                                        dtype=torch.long,
                                        device=device)
        self._lora_indices_per_batch = torch.empty(max_batches,
                                                   dtype=torch.long,
                                                   device=device)
        self.device: torch.device = device
        self.max_length: int = 0
        self.token_nums: int = 0
        self.batch_size: int = -1
        self.is_prefill = False
        self.no_lora = False

    def _update_base_metadata(
        self,
        mapping: "LoRAMapping",
        lora_index_to_id: List[Optional[int]],
        max_loras: int,
        vocab_size: int,
        extra_vocab_size: int,
        long_lora_context: Optional["LongContextLoRAContext"] = None,
    ) -> None:
        """
        Recompute the index tensors via ``convert_mapping`` and copy them
        into the leading part of the preallocated buffers.

        ``indices_len`` is overwritten in place with the lengths returned by
        ``convert_mapping``. When no long-lora offsets are returned, the
        long-lora buffer is zeroed.
        """
        (
            base_indices,
            sampler_indices,
            sampler_indices_padded,
            embeddings_indices,
            long_lora_offsets_tensor,
            indices_len,
        ) = convert_mapping(
            mapping,
            lora_index_to_id,
            max_loras,
            vocab_size,
            extra_vocab_size,
            self.device,
            long_lora_context,
        )
        self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
        self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
        self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
            sampler_indices_padded)
        self._embeddings_indices[:embeddings_indices.
                                 shape[0], :embeddings_indices.shape[1]].copy_(
                                     embeddings_indices)
        if long_lora_offsets_tensor is not None:
            self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
                long_lora_offsets_tensor)
        else:
            self._long_lora_indices.zero_()
        # Slice-assign so the existing list object is reused.
        self.indices_len[:] = indices_len

    def _update_prefill_metadata(self,
                                 token_lora_tensor: torch.Tensor) -> None:
        """
        Refresh the per-batch prefill metadata from the per-token lora
        indices using ``compute_meta``.

        Args:
            token_lora_tensor (torch.Tensor): Lora index of each token.
        """
        (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
         batch_size, max_length, token_nums,
         no_lora) = compute_meta(token_lora_tensor)

        self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_(
            b_seq_start_tensor)
        self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor)
        self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_(
            lora_indices_tensor)
        self.batch_size = batch_size
        self.max_length = max_length
        self.token_nums = token_nums
        self.no_lora = no_lora

    def _apply_bias(
        self,
        indices: torch.Tensor,
        output: torch.Tensor,
        output_slices: Tuple[int, ...],
        lora_bias_stacked: Tuple[Optional[torch.Tensor], ...],
    ) -> torch.Tensor:
        """Applies bias to output

        The addition is done in place on a 2-D view of ``output``, so the
        caller's tensor is modified; the return value is that view reshaped
        back to the shape of ``output``.

        Input shapes:
            lora_bias_stacked:      3 element tuple of (num_loras, output_dim)
            indices:           (batch_size)
            output:            (batch_size, q_slice_size + 2*kv_slice_size)
            output_slices:     n element tuple of (slice_size...),
                            where n is number of slices; one entry of
                            lora_bias_stacked is read per slice
        """
        org_output = output
        output = output.view(-1, output.shape[-1])
        indices = indices.view(-1)
        # Rows whose index is -1 have no lora; the mask does not depend on
        # the slice, so compute it once.
        no_lora_rows = indices == -1

        offset_left = 0
        for slice_idx, slice_size in enumerate(output_slices):
            bias = lora_bias_stacked[slice_idx]
            if bias is not None:
                bias = bias.view(-1, bias.shape[-1])
                # Indexing with a tensor yields a copy, so zeroing the
                # no-lora rows below leaves the stacked bias untouched.
                bias = bias[indices]
                bias[no_lora_rows] = 0
                output[:, offset_left:offset_left + slice_size] += bias
            offset_left += slice_size

        return output.view_as(org_output)

    @property
    def prefill_metadata(
        self
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
        """
        This property provides a convenient way to access the necessary
        metadata for prefill-related  kernel computations.
            1. seq_start_locs: Tensor of sequence start positions.
            2. seq_lengths: Tensor of sequence lengths.
            3. lora_indices_per_batch: Tensor of lora indices, and an index of
                -1 means no lora should be applied.
            4. batch_size: Batch size after clustering identical lora indices.
            5. max_length: The maximum sequence length in the batch.
            6. token_nums: The token numbers in the batch.

        Note: ``batch_size`` starts at -1, so before the first prefill update
        the tensor slices are ``[:-1]`` of the uninitialized buffers.
        """
        return (self._seq_start_locs[:self.batch_size],
                self._seq_lengths[:self.batch_size],
                self._lora_indices_per_batch[:self.batch_size],
                self.batch_size, self.max_length, self.token_nums)

    @property
    def token_lora_indices(self) -> torch.Tensor:
        """
        This property provides the lora indices corresponding to each token
        in the batch. An index of -1 means no lora should be applied.
        """
        token_lora_len = self.indices_len[0]
        return self._token_lora_indices[:token_lora_len]

    @property
    def sampler_indices(self) -> torch.Tensor:
        """
        This property is used to access the lora indices specifically for
        LogitsProcessorWithLoRA.
        """
        sampler_indices_len = self.indices_len[1]
        return self._sampler_indices[:sampler_indices_len]

    @property
    def sampler_indices_padded(self) -> torch.Tensor:
        """
        This property provides access to padded sampler indices.
        """
        indices_padded_len = self.indices_len[2]
        return self._sampler_indices_padded[:indices_padded_len]

    @property
    def embeddings_indices(self) -> torch.Tensor:
        """
        This property provides access to the indices used for lora embeddings,
        specifically for VocabParallelEmbeddingWithLoRA.
        """
        embeddings_indices_len = self.indices_len[3]
        return self._embeddings_indices[:, :embeddings_indices_len]

    @property
    def long_lora_indices(self) -> torch.Tensor:
        """
        This property provides access to the indices used for long context
        lora, specifically for LinearScalingRotaryEmbeddingWithLoRA.
        """
        long_lora_len = self.indices_len[4]
        return self._long_lora_indices[:long_lora_len]

    def update_metadata(
            self,
            mapping: "LoRAMapping",
            lora_index_to_id: List[Optional[int]],
            max_loras: int,
            vocab_size: int,
            extra_vocab_size: int,
            long_lora_context: Optional["LongContextLoRAContext"] = None,
            **kwargs) -> None:
        """
        Update the lora-related metadata.

        Always refreshes the index buffers; the prefill metadata is refreshed
        only when ``mapping.is_prefill`` is truthy. ``self.is_prefill`` is set
        to match.
        """
        self._update_base_metadata(mapping, lora_index_to_id, max_loras,
                                   vocab_size, extra_vocab_size,
                                   long_lora_context)
        if mapping.is_prefill:
            # Update metadata required for prefill-related operators.
            self._update_prefill_metadata(self.token_lora_indices)
            self.is_prefill = True
        else:
            self.is_prefill = False

    @abstractmethod
    def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
                   x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
                   scale: float, **kwargs) -> Optional[torch.Tensor]:
        """
        Performs GEMM  for multiple slices of lora_a.

        Semantics:
        for i in range(len(lora_a_stacked)):
            y[i] += (x @ lora_a_stacked[i]) * scale

        Args:
            y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
            x (torch.Tensor): Input tensor
            lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
            scale (float): Scaling factor for the operation

        """
        # TODO: implement it based on torch ops
        raise NotImplementedError

    @abstractmethod
    def add_expand(self,
                   y: torch.Tensor,
                   x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
                   lora_b_stacked: Tuple[torch.Tensor, ...],
                   lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
                   output_slices: Tuple[int, ...],
                   offset_start: int = 0,
                   add_inputs: bool = True,
                   **kwargs) -> Optional[torch.Tensor]:
        """
        Performs GEMM and bias addition for multiple slices of lora_b.

        Semantics:
            offset = offset_start
            for i in range(len(lora_b_stacked)):
                slice = output_slices[i]
                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
                    lora_bias_stacked[i]
                offset += slice

        Args:
            y (torch.Tensor): Output tensor.
            x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
            lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
            lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
                bias's weight
            output_slices (Tuple[int, ...]): Every slice's size
            offset_start (int): The starting position of y, defaults to 0
            add_inputs (bool):  Defaults to True.

        """
        # TODO: implement it based on torch ops
        raise NotImplementedError

    @abstractmethod
    def add_lora_embedding(self,
                           y: torch.Tensor,
                           x: torch.Tensor,
                           lora_b_stacked: torch.Tensor,
                           add_inputs: bool = True,
                           **kwargs) -> Optional[torch.Tensor]:
        """
        Applies lora  specifically for VocabParallelEmbeddingWithLoRA,
        and this layer only requires the expand operation.

        Semantics:
            y += x @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            add_inputs (bool): Default to True.
        """
        # TODO: implement it based on torch ops
        raise NotImplementedError

    @abstractmethod
    def add_lora_linear(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: Tuple[torch.Tensor, ...],
                        lora_b_stacked: Tuple[torch.Tensor, ...],
                        lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
                        scale: float,
                        output_slices: Tuple[int, ...],
                        *,
                        buffer: Optional[Tuple[torch.Tensor, ...]] = None,
                        **kwargs) -> Optional[torch.Tensor]:
        """
        Applicable to linear-related lora.

        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (
                    x[i].unsqueeze(0)
                    @ lora_a_stacked[indices[i], layer_idx, :, :]
                    @ lora_b_stacked[indices[i], layer_idx, :, :]
                    * scale
                    ).squeeze(0)+lora_bias_stacked[i]

        Args:
            y (torch.Tensor): Output tensor. Will be changed in-place.
            x (torch.Tensor): Input tensor
            lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight.
            lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight.
            lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
            scale (float): Scaling factor.
            output_slices (Tuple[int, ...]): Every slice's size.
            buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
        """
        # TODO: implement it based on torch ops
        raise NotImplementedError

    @abstractmethod
    def add_lora_logits(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: torch.Tensor,
                        lora_b_stacked: torch.Tensor,
                        scale,
                        *,
                        buffer: Optional[torch.Tensor] = None,
                        **kwargs) -> Optional[torch.Tensor]:
        """
        Applies lora  specifically for LogitsProcessorWithLoRA.

        Semantics:
            buffer = (x @ lora_a_stacked) * scale
            y += buffer @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_a_stacked (torch.Tensor): lora_a's weights.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            scale (float): Scaling factor.
            buffer (Optional[torch.Tensor]): Default to None.
        """
        # TODO: implement it based on torch ops
        raise NotImplementedError