punica_gpu.py 15.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
4
5
6
7
8
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). 
Punica: Multi-Tenant LoRA Serving. 
https://arxiv.org/abs/2310.18547
"""

9
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final
10
11
12

import torch

13
14
import vllm.envs as env
from vllm.lora.layers import LoRAMapping
15
16
17
from vllm.triton_utils import HAS_TRITON

if HAS_TRITON:
18
19
20
21
22
23
24
25
26
    if env.VLLM_USE_V1:
        from vllm.lora.ops.triton_ops.v1 import (V1KernelMeta, v1_expand,
                                                 v1_shrink)
    else:
        from vllm.lora.ops.triton_ops import bgmv_expand
        from vllm.lora.ops.triton_ops import bgmv_expand_slice
        from vllm.lora.ops.triton_ops import bgmv_shrink
        from vllm.lora.ops.triton_ops import sgmv_expand
        from vllm.lora.ops.triton_ops import sgmv_shrink
27
28
29

from .punica_base import PunicaWrapperBase

30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
if TYPE_CHECKING:
    # avoid circular import
    from vllm.lora.models import LongContextLoRAContext


class V1KernelMixin:
    """
    Mixin that owns the V1 kernel metadata objects and wraps the
    `v1_shrink` / `v1_expand` kernel calls.

    Two metadata objects are kept: `token_mapping_v1_meta`, used by the
    shrink/expand helpers below, and `prompt_mapping_v1_meta`, which is only
    created and refreshed here.
    """

    def _v1_make_metadata(self, max_loras: int, max_num_batched_tokens: int,
                          max_batches: int, device: Union[torch.device, str]):
        """
        Create both V1 metadata objects via `V1KernelMeta.make`.

        Args:
            max_loras (int): Passed to both metadata objects.
            max_num_batched_tokens (int): Size argument for the token
                mapping metadata.
            max_batches (int): Size argument for the prompt mapping metadata.
            device (Union[torch.device, str]): Device for both objects.
        """
        self.token_mapping_v1_meta = V1KernelMeta.make(
            max_loras, max_num_batched_tokens, device=device)
        self.prompt_mapping_v1_meta = V1KernelMeta.make(max_loras,
                                                        max_batches,
                                                        device=device)

    def _v1_prepare_metadata_tensors(self, token_lora_indices: torch.Tensor,
                                     sampler_indices: torch.Tensor):
        """
        Refresh both metadata objects from the current index tensors.

        Args:
            token_lora_indices (torch.Tensor): Handed to the token mapping
                metadata.
            sampler_indices (torch.Tensor): Handed to the prompt mapping
                metadata.
        """
        token_meta = self.token_mapping_v1_meta
        prompt_meta = self.prompt_mapping_v1_meta
        token_meta.prepare_tensors(token_lora_indices)
        prompt_meta.prepare_tensors(sampler_indices)

    def _v1_apply_shrink(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: Tuple[torch.Tensor, ...],
        scale: float,
    ):
        """
        Call `v1_shrink` with the token mapping metadata.

        Args:
            y (torch.Tensor): Output tensor handed to the kernel.
            x (torch.Tensor): Input tensor; `x.size(0)` is what the metadata
                arguments are requested for.
            w_t_all (Tuple[torch.Tensor, ...]): Weights handed to the kernel.
            scale (float): Scaling factor handed to the kernel.
        """
        token_meta_args = self.token_mapping_v1_meta.meta_args(x.size(0))
        v1_shrink(x, w_t_all, y, *token_meta_args, scale)

    def _v1_apply_expand(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: Tuple[torch.Tensor, ...],
        offset_start: int,
        add_inputs: bool,
    ):
        """
        Call `v1_expand` with the token mapping metadata.

        Args:
            y (torch.Tensor): Output tensor handed to the kernel.
            x (torch.Tensor): Input tensor; `x.size(0)` is what the metadata
                arguments are requested for.
            w_t_all (Tuple[torch.Tensor, ...]): Weights handed to the kernel.
            offset_start (int): Forwarded as the `offset_start` keyword.
            add_inputs (bool): Forwarded as the `add_inputs` keyword.
        """
        token_meta_args = self.token_mapping_v1_meta.meta_args(x.size(0))
        v1_expand(x,
                  w_t_all,
                  y,
                  *token_meta_args,
                  offset_start=offset_start,
                  add_inputs=add_inputs)

83
84

@final
85
class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin):
86
87
88
89
90
91
92
93
94
95
96
    """
    PunicaWrapperGPU is designed to manage and provide metadata for the punica 
    kernel. The main function is to maintain the state information for 
    Multi-LoRA, and to provide the interface for the punica triton kernel.
    """

    def __init__(self, max_num_batched_tokens: int, max_batches: int,
                 device: Union[torch.device, str], **kwargs):
        """
        Initialise the base wrapper state and, on V1, the kernel metadata.

        Args:
            max_num_batched_tokens (int): Forwarded to the base class and to
                the V1 token mapping metadata.
            max_batches (int): Forwarded to the base class and to the V1
                prompt mapping metadata.
            device (Union[torch.device, str]): Forwarded to the base class
                and to the V1 metadata.
            **kwargs: Must contain `max_loras`; a missing key raises
                KeyError.
        """
        PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
                                   device)

        self.max_loras = kwargs['max_loras']

        # The V1 metadata objects are only built when the V1 path is enabled.
        if not env.VLLM_USE_V1:
            return
        self._v1_make_metadata(self.max_loras, max_num_batched_tokens,
                               max_batches, device)

    def update_metadata(
            self,
            mapping: LoRAMapping,
            lora_index_to_id: List[Optional[int]],
            max_loras: int,
            vocab_size: int,
            extra_vocab_size: int,
            long_lora_context: Optional["LongContextLoRAContext"] = None,
            **kwargs):
        """
        Refresh the wrapper's metadata for a new LoRA mapping.

        Without V1 the call is forwarded unchanged (including `**kwargs`) to
        `PunicaWrapperBase.update_metadata`. With V1, `is_prefill` is copied
        from the mapping, the base metadata is updated, and the V1 metadata
        tensors are rebuilt from `token_lora_indices` and `sampler_indices`.

        Args:
            mapping (LoRAMapping): Current mapping; `mapping.is_prefill` is
                read on the V1 path.
            lora_index_to_id (List[Optional[int]]): Forwarded as-is.
            max_loras (int): Forwarded as-is.
            vocab_size (int): Forwarded as-is.
            extra_vocab_size (int): Forwarded as-is.
            long_lora_context (Optional[LongContextLoRAContext]): Forwarded
                as-is. Defaults to None.
            **kwargs: Only used on the non-V1 path.
        """
        if not env.VLLM_USE_V1:
            # Forward to base class update_metadata
            PunicaWrapperBase.update_metadata(self, mapping, lora_index_to_id,
                                              max_loras, vocab_size,
                                              extra_vocab_size,
                                              long_lora_context, **kwargs)
            return

        self.is_prefill = mapping.is_prefill
        self._update_base_metadata(mapping, lora_index_to_id, max_loras,
                                   vocab_size, extra_vocab_size,
                                   long_lora_context)
        self._v1_prepare_metadata_tensors(self.token_lora_indices,
                                          self.sampler_indices)

127
    def _apply_shrink_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: Tuple[torch.Tensor, ...],
        scale: float,
    ):
        """
        Call `sgmv_shrink` with the prefill metadata.

        Nothing is done when `self.no_lora` is set, i.e. there is no LoRA
        request to serve.

        Args:
            y (torch.Tensor): Output tensor handed to the kernel.
            x (torch.Tensor): Input tensor.
            w_t_all (Tuple[torch.Tensor, ...]): Weights handed to the kernel.
            scale (float): Scaling factor.
        """
        if not self.no_lora:
            sgmv_shrink(x, w_t_all, y, *self.prefill_metadata, scale)

145
    def _apply_shrink_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        """
        Call `bgmv_shrink` with the per-token LoRA indices.

        Args:
            y (torch.Tensor): Output tensor handed to the kernel.
            x (torch.Tensor): Input tensor.
            w_t_all (torch.Tensor): Weights for a single slice.
            scale (float): Scaling factor.
        """
        token_indices = self.token_lora_indices
        bgmv_shrink(x, w_t_all, y, token_indices, scale)

154
    def _apply_expand_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: Tuple[torch.Tensor, ...],
        offset_start: int,
        add_inputs: bool,
    ):
        """
        Call `sgmv_expand` with the prefill metadata.

        Nothing is done when `self.no_lora` is set, i.e. there is no LoRA
        request to serve.

        Args:
            y (torch.Tensor): Output tensor handed to the kernel.
            x (torch.Tensor): Input tensor.
            w_t_all (Tuple[torch.Tensor, ...]): Weights handed to the kernel.
            offset_start (int): Forwarded as the `offset_start` keyword.
            add_inputs (bool): Forwarded as the `add_inputs` keyword.
        """
        if not self.no_lora:
            sgmv_expand(x,
                        w_t_all,
                        y,
                        *self.prefill_metadata,
                        offset_start=offset_start,
                        add_inputs=add_inputs)

175
    def _apply_expand_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: Optional[int],
        y_slice_size: Optional[int],
        add_inputs: bool,
    ):
        """
        Call `bgmv_expand_slice` with the per-token LoRA indices.

        Args:
            y (torch.Tensor): Output tensor handed to the kernel.
            x (torch.Tensor): Input tensor for a single slice.
            w_t_all (torch.Tensor): Weights for a single slice.
            y_offset (Optional[int]): Offset argument handed to the kernel.
            y_slice_size (Optional[int]): Slice-size argument handed to the
                kernel.
            add_inputs (bool): Forwarded to the kernel.
        """
        token_indices = self.token_lora_indices
        bgmv_expand_slice(x, w_t_all, y, token_indices, y_offset,
                          y_slice_size, add_inputs)
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208

    def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
                   x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
                   scale: float, **kwargs):
        """
        Performs GEMM for multiple slices of lora_a.

        Dispatch:
            * V1 enabled: a single `_v1_apply_shrink` call.
            * `is_prefill` true: a single fused `_apply_shrink_prefill` call.
            * otherwise (decode): one `_apply_shrink_decode` call per slice.

        Semantics:
        for i in range(len(lora_a_stacked)):
            y[i] += (x @ lora_a_stacked[i]) * scale

        Args:
            y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
            x (torch.Tensor): Input tensor; flattened to 2-D on its last
                dimension before use
            lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
            scale (float): Scaling factor for the operation
        """
        x = x.view(-1, x.shape[-1])

        if env.VLLM_USE_V1:
            self._v1_apply_shrink(y, x, lora_a_stacked, scale)  # type: ignore
            return

        if self.is_prefill:
            # NOTE fused kernel
            self._apply_shrink_prefill(
                y,  # type: ignore
                x,
                lora_a_stacked,
                scale)
            return

        # TODO fuse these kernels
        for slice_idx, lora_a in enumerate(lora_a_stacked):
            self._apply_shrink_decode(y[slice_idx], x, lora_a, scale)
225
226
227
228
229
230
231
232

    def add_expand(self,
                   y: torch.Tensor,
                   x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
                   lora_b_stacked: Tuple[torch.Tensor, ...],
                   lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
                   output_slices: Tuple[int, ...],
                   offset_start: int = 0,
                   add_inputs=True,
                   **kwargs) -> None:
        """
        Performs GEMM and bias addition for multiple slices of lora_b.

        `y` is updated in place through a 2-D view, so nothing is returned.
        `add_inputs` is forwarded on every path (V1, prefill and decode);
        previously only the decode path honored it.

        Semantics (with add_inputs=True):
            offset = offset_start
            for i in range(len(lora_b_stacked)):
                slice = output_slices[i]
                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
                    lora_bias_stacked[i]
                offset += slice

        Args:
            y (torch.Tensor): Output tensor.
            x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
            lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
            lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
                bias's weight; applied via `_apply_bias` before the expand
                kernels run when not None
            output_slices (Tuple[int, ...]): Every slice's size
            offset_start (int): Column offset of the first slice in `y`.
                Defaults to 0.
            add_inputs (bool): Forwarded to the expand kernels.
                Defaults to True.
        """
        # The kernels write into this view, which shares storage with the
        # caller's tensor, so no reshape back is needed afterwards.
        y = y.view(-1, y.shape[-1])
        if lora_bias_stacked is not None:
            self._apply_bias(self.token_lora_indices, y, output_slices,
                             lora_bias_stacked)

        if env.VLLM_USE_V1:
            # TODO (varun): Profile with add_inputs = False. i.e. move the
            # addition out of the kernel
            self._v1_apply_expand(
                y,
                x,  # type: ignore
                lora_b_stacked,
                offset_start,
                add_inputs=add_inputs)
        elif self.is_prefill:
            # NOTE fused kernel
            self._apply_expand_prefill(
                y,
                x,  # type: ignore
                lora_b_stacked,
                offset_start,
                add_inputs=add_inputs)
        else:
            # TODO fuse these kernels
            for slice_idx, lora_b in enumerate(lora_b_stacked):
                slice_size = output_slices[slice_idx]
                self._apply_expand_decode(
                    y,
                    x[slice_idx],
                    lora_b,
                    offset_start,
                    slice_size,
                    add_inputs=add_inputs,
                )
                offset_start += slice_size

    def add_lora_embedding(self,
                           y: torch.Tensor,
                           x: torch.Tensor,
                           lora_b_stacked: torch.Tensor,
                           add_inputs: bool = True,
                           **kwargs) -> None:
        """
        Applies lora specifically for VocabParallelEmbeddingWithLoRA.

        Dispatch:
            * V1 enabled: `_v1_apply_expand` on a single-slice input.
            * `is_prefill` true: `sgmv_expand` on a single-slice input.
            * otherwise (decode): `bgmv_expand`.

        Semantics:
            y += x @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            add_inputs (bool): Default to True.
        """
        if env.VLLM_USE_V1:
            # The V1 kernel takes stacked slices, so wrap the single slice.
            self._v1_apply_expand(y,
                                  x.unsqueeze(dim=0), (lora_b_stacked, ),
                                  offset_start=0,
                                  add_inputs=add_inputs)
            return

        if not self.is_prefill:
            bgmv_expand(x, lora_b_stacked, y, self.token_lora_indices,
                        add_inputs)
            return

        sgmv_expand(
            x.unsqueeze(dim=0),
            (lora_b_stacked, ),
            y,
            *self.prefill_metadata,
            offset_start=0,
            add_inputs=add_inputs,
        )
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374

    def add_lora_linear(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: Tuple[torch.Tensor, ...],
                        lora_b_stacked: Tuple[torch.Tensor, ...],
                        lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
                        scale: float,
                        output_slices: Tuple[int, ...],
                        *,
                        buffer: Optional[Tuple[torch.Tensor, ...]] = None,
                        **kwargs) -> None:
        """
        Applicable to linear-related lora.

        Runs in three steps: optional bias via `_apply_bias`, then
        `add_shrink` into `buffer`, then `add_expand` from `buffer` into `y`
        (called with no bias and `add_inputs=True`).

        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (
                    x[i].unsqueeze(0)
                    @ lora_a_stacked[indices[i], layer_idx, :, :]
                    @ lora_b_stacked[indices[i], layer_idx, :, :]
                    * scale
                    ).squeeze(0)+lora_bias_stacked[i]

        Args:
            y (torch.Tensor): Output tensor. Will be changed in-place.
            x (torch.Tensor): Input tensor
            lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight.
            lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight.
            lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
            scale (float): Scaling factor.
            output_slices (Tuple[int, ...]): Every slice's size.
            buffer (Optional[Tuple[torch.Tensor, ...]]): Intermediate storage
                between shrink and expand. Defaults to None, in which case a
                zeroed float32 buffer is allocated on `x.device`.
            **kwargs: Forwarded to both `add_shrink` and `add_expand`.
        """
        assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)

        if lora_bias_stacked is not None:
            assert len(lora_bias_stacked) == len(output_slices)
            y = self._apply_bias(self.token_lora_indices, y, output_slices,
                                 lora_bias_stacked)

        if buffer is None:
            last_dim = lora_b_stacked[0].size(-1)
            buffer_shape = (len(output_slices), x.size(0), last_dim)
            # We set the buffer to be float32 by default ,refer to:
            # https://github.com/triton-lang/triton/issues/1387
            buffer = torch.zeros(  # type: ignore
                buffer_shape,
                dtype=torch.float32,
                device=x.device,
            )

        self.add_shrink(
            buffer,  # type: ignore
            x,
            lora_a_stacked,
            scale,
            **kwargs)
        # The bias was already applied above, so None is passed here.
        self.add_expand(
            y,
            buffer,  # type: ignore
            lora_b_stacked,
            None,
            output_slices,
            add_inputs=True,
            **kwargs)
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428

    def add_lora_logits(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: torch.Tensor,
                        lora_b_stacked: torch.Tensor,
                        scale,
                        *,
                        buffer: Optional[torch.Tensor] = None,
                        **kwargs) -> None:
        """
        Applies lora specifically for LogitsProcessorWithLoRA.

        `y` is updated in place through a 2-D view, so nothing is returned.
        With V1 the `v1_shrink` / `v1_expand` kernels are called with the
        prompt mapping metadata; otherwise `bgmv_shrink` / `bgmv_expand` are
        called with `sampler_indices`.

        Semantics:
            buffer = (x @ lora_a_stacked) * scale
            y += buffer @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_a_stacked (torch.Tensor): lora_a's weights.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            scale (float): Scaling factor.
            buffer (Optional[torch.Tensor]): Intermediate storage between
                shrink and expand. Default to None, in which case a zeroed
                float32 buffer is allocated on `x.device`.
        """
        # The kernels write into this view, which shares storage with the
        # caller's tensor, so no reshape back is needed afterwards.
        y = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])

        if buffer is None:
            r = lora_b_stacked.size(-1)
            # We set the buffer to be float32 by default ,refer to:
            # https://github.com/triton-lang/triton/issues/1387
            buffer = torch.zeros((x.size(0), r),
                                 dtype=torch.float32,
                                 device=x.device)

        if env.VLLM_USE_V1:
            v1_shrink(x, [lora_a_stacked], buffer.unsqueeze(dim=0),
                      *self.prompt_mapping_v1_meta.meta_args(x.size(0)), scale)

            v1_expand(buffer.unsqueeze(dim=0), [lora_b_stacked],
                      y,
                      *self.prompt_mapping_v1_meta.meta_args(buffer.size(0)),
                      add_inputs=True)
        else:
            # V0 LogitsProcessorWithLoRA always using bgmv.
            bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
            bgmv_expand(buffer,
                        lora_b_stacked,
                        y,
                        self.sampler_indices,
                        add_inputs=True)