# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""

10
from typing import final
11
12
13

import torch

14
from vllm.lora.layers import LoRAMapping
15
from vllm.triton_utils import HAS_TRITON, triton
16
from vllm.utils.math_utils import round_up
17
18

if HAS_TRITON:
    from vllm.lora.ops.triton_ops import (
        LoRAKernelMeta,
        fused_moe_lora,
        lora_expand,
        lora_shrink,
    )

from vllm import _custom_ops as ops
27
28
29

from .punica_base import PunicaWrapperBase

30

31
@final
class PunicaWrapperGPU(PunicaWrapperBase):
    """
    PunicaWrapperGPU is designed to manage and provide metadata for the punica
    kernel. The main function is to maintain the state information for
    Multi-LoRA, and to provide the interface for the punica triton kernel.

    Two sets of kernel metadata are maintained:

    * ``token_mapping_meta``: refreshed from ``token_lora_indices`` and used
      by the shrink/expand, embedding and fused-MoE paths.
    * ``prompt_mapping_meta``: refreshed from ``sampler_indices`` and used by
      ``add_lora_logits``.
    """

    def __init__(
        self,
        max_num_batched_tokens: int,
        max_batches: int,
        device: torch.device | str,
        **kwargs,
    ):
        """
        Args:
            max_num_batched_tokens (int): Token capacity; also sizes
                ``token_mapping_meta``.
            max_batches (int): Batch capacity; also sizes
                ``prompt_mapping_meta``.
            device (torch.device | str): Device passed to the base class and
                used for the kernel metadata tensors.
            **kwargs: Must contain ``max_loras`` (a missing key raises
                ``KeyError``); other keys are ignored here.
        """
        PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device)

        self.max_loras = kwargs["max_loras"]

        # Per-token LoRA metadata; refreshed in update_metadata() from
        # `token_lora_indices`.
        self.token_mapping_meta = LoRAKernelMeta.make(
            self.max_loras, max_num_batched_tokens, device=device
        )

        # Per-prompt LoRA metadata; refreshed in update_metadata() from
        # `sampler_indices`.
        self.prompt_mapping_meta = LoRAKernelMeta.make(
            self.max_loras, max_batches, device=device
        )

    def update_metadata(
        self,
        mapping: LoRAMapping,
        lora_index_to_id: list[int | None],
        max_loras: int,
        vocab_size: int,
        extra_vocab_size: int,
        **kwargs,
    ):
        """
        Refresh the wrapper's state for a new batch.

        Caches ``mapping.is_prefill``, delegates the index bookkeeping to
        ``_update_base_metadata`` and then rebuilds both kernel metadata
        objects from the resulting index tensors.

        Args:
            mapping (LoRAMapping): LoRA mapping for the current batch.
            lora_index_to_id (list[int | None]): Forwarded to
                ``_update_base_metadata``.
            max_loras (int): Forwarded to ``_update_base_metadata``.
            vocab_size (int): Forwarded to ``_update_base_metadata``.
            extra_vocab_size (int): Forwarded to ``_update_base_metadata``.
            **kwargs: Accepted for interface compatibility; unused here.
        """
        self.is_prefill = mapping.is_prefill
        self._update_base_metadata(
            mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size
        )

        # Prepare cuda kernel metadata tensors
        self.token_mapping_meta.prepare_tensors(self.token_lora_indices)
        self.prompt_mapping_meta.prepare_tensors(self.sampler_indices)

    def add_shrink(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_a_stacked: tuple[torch.Tensor, ...],
        scale: float,
        **kwargs,
    ):
        """
        Performs GEMM for multiple slices of lora_a.

        Semantics:
        for i in range(len(lora_a_stacked)):
            y[i] += (x @ lora_a_stacked[i]) * scale

        Args:
            y (torch.Tensor): Output tensors; ``add_lora_linear`` passes a
                float32 buffer of shape (num_slices, num_tokens, r).
            x (torch.Tensor): Input tensor; flattened to 2-D over its last
                dimension before the kernel call.
            lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
            scale (float): Scaling factor for the operation
            **kwargs: Unused.
        """
        x = x.view(-1, x.shape[-1])
        lora_shrink(
            x,
            lora_a_stacked,
            y,
            *self.token_mapping_meta.meta_args(x.size(0)),
            scale,
        )

    def add_expand(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_b_stacked: tuple[torch.Tensor, ...],
        output_slices: tuple[int, ...],
        offset_start: int = 0,
        add_inputs: bool = True,
        **kwargs,
    ) -> None:
        """
        Performs GEMM for multiple slices of lora_b.

        Semantics:
            offset = offset_start
            for i in range(len(lora_b_stacked)):
                slice = output_slices[i]
                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
                offset += slice

        Args:
            y (torch.Tensor): Output tensor. Flattened to 2-D over its last
                dimension (a view of ``y``) and updated in place.
            x (torch.Tensor): Input tensor of shape
                (num_slices, num_tokens, r).
            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
            output_slices (tuple[int, ...]): Every slice's size
            offset_start (int): Column of ``y`` at which the first slice is
                written. Defaults to 0.
            add_inputs (bool): Forwarded to the expand kernel. Defaults to
                True.
            **kwargs: Unused.
        """
        # Reshape through a view so the kernel's in-place writes land in the
        # caller's tensor; no reshape back is needed.
        y_2d = y.view(-1, y.shape[-1])

        assert x.ndim == 3
        assert x.size(0) == len(output_slices)
        num_tokens = x.size(1)  # dimension 0 indexes the slices

        lora_expand(
            x,
            lora_b_stacked,
            y_2d,
            *self.token_mapping_meta.meta_args(num_tokens),
            offset_start=offset_start,
            # Previously hard-coded to True, silently ignoring the argument.
            add_inputs=add_inputs,
        )

    def add_lora_embedding(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_b_stacked: torch.Tensor,
        add_inputs: bool = True,
        **kwargs,
    ) -> None:
        """
        Applies lora specifically for VocabParallelEmbeddingWithLoRA.

        Semantics:
            y += x @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor; given a leading slice dimension
                of size 1 before the kernel call.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            add_inputs (bool): Forwarded to the expand kernel. Defaults to
                True.
            **kwargs: Unused.
        """
        lora_expand(
            x.unsqueeze(dim=0),
            (lora_b_stacked,),
            y,
            *self.token_mapping_meta.meta_args(x.size(0)),
            offset_start=0,
            add_inputs=add_inputs,
        )

    def add_lora_linear(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_a_stacked: tuple[torch.Tensor, ...],
        lora_b_stacked: tuple[torch.Tensor, ...],
        scale: float,
        output_slices: tuple[int, ...],
        *,
        buffer: torch.Tensor | None = None,
        **kwargs,
    ) -> None:
        """
        Applicable to linear-related lora.

        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (
                    x[i].unsqueeze(0)
                    @ lora_a_stacked[indices[i], layer_idx, :, :]
                    @ lora_b_stacked[indices[i], layer_idx, :, :]
                    * scale
                    ).squeeze(0)

        Implemented as a shrink into a float32 intermediate buffer followed
        by an expand that accumulates into ``y``.

        Args:
            y (torch.Tensor): Output tensor. Will be changed in-place.
            x (torch.Tensor): Input tensor
            lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
            scale (float): Scaling factor.
            output_slices (tuple[int, ...]): Every slice's size.
            buffer (Optional[torch.Tensor]): Must be None; the buffer is
                allocated internally.
            **kwargs: Forwarded to ``add_shrink`` and ``add_expand``.

        Raises:
            AssertionError: If the stacked weights and ``output_slices``
                differ in length, or if ``buffer`` is not None.
        """
        assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
        assert buffer is None, (
            "To minimize overhead, the buffer should be created by "
            ".add_lora_linear() instead of being passed in."
        )
        # Width of the intermediate buffer, taken from lora_b's last dim.
        r = lora_b_stacked[0].size(-1)
        # We set the buffer to be float32 by default, refer to:
        # https://github.com/triton-lang/triton/issues/1387
        # Note: buffer is zeroed inside the shrink op
        buffer = torch.empty(
            (len(output_slices), x.size(0), r), dtype=torch.float32, device=x.device
        )

        self.add_shrink(
            buffer,  # type: ignore
            x,
            lora_a_stacked,
            scale,
            **kwargs,
        )
        self.add_expand(
            y,
            buffer,  # type: ignore
            lora_b_stacked,
            output_slices,
            add_inputs=True,
            **kwargs,
        )

    def add_lora_logits(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_a_stacked: torch.Tensor,
        lora_b_stacked: torch.Tensor,
        scale: float,
        *,
        buffer: torch.Tensor | None = None,
        **kwargs,
    ) -> None:
        """
        Applies lora specifically for LogitsProcessorWithLoRA.

        Semantics:
            buffer = (x @ lora_a_stacked) * scale
            y += buffer @ lora_b_stacked

        Unlike the other entry points, this uses ``prompt_mapping_meta``
        (built from the sampler indices) rather than the per-token metadata.

        Args:
            y (torch.Tensor): Output tensor; flattened to 2-D over its last
                dimension (a view) and updated in place.
            x (torch.Tensor): Input tensor; flattened to 2-D over its last
                dimension.
            lora_a_stacked (torch.Tensor): lora_a's weights.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            scale (float): Scaling factor.
            buffer (Optional[torch.Tensor]): Must be None; the buffer is
                allocated internally.
            **kwargs: Unused.

        Raises:
            AssertionError: If ``buffer`` is not None.
        """
        y_2d = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])
        # Width of the intermediate buffer, taken from lora_b's last dim.
        r = lora_b_stacked.size(-1)

        assert buffer is None, (
            "To minimize overhead, the buffer should be created by "
            ".add_lora_logits() instead of being passed in."
        )
        # We set the buffer to be float32 by default, refer to:
        # https://github.com/triton-lang/triton/issues/1387
        # Note: buffer is zeroed inside the shrink op
        buffer = torch.empty((x.size(0), r), dtype=torch.float32, device=x.device)

        lora_shrink(
            x,
            [lora_a_stacked],
            buffer.unsqueeze(dim=0),
            *self.prompt_mapping_meta.meta_args(x.size(0)),
            scale,
        )

        lora_expand(
            buffer.unsqueeze(dim=0),
            [lora_b_stacked],
            y_2d,
            *self.prompt_mapping_meta.meta_args(buffer.size(0)),
            add_inputs=True,
        )

    def moe_lora_align_block_size(
        self,
        topk_ids: torch.Tensor,
        num_tokens: int,
        block_size: int,
        num_experts: int,
        max_loras: int,
        adapter_enabled: torch.Tensor,
        expert_map: torch.Tensor | None = None,
        pad_sorted_ids: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Aligns tokens and experts into block-sized chunks for LoRA-based
        mixture-of-experts (MoE) execution.

        Allocates one region per LoRA slot (``max_loras`` of them) in each
        output buffer; ``ops.moe_lora_align_block_size`` fills them in.

        Args:
            topk_ids (torch.Tensor): Routed expert ids; its element count
                sizes the buffers and it is forwarded to the op.
            num_tokens (int): Number of tokens used to slice the per-token
                LoRA metadata.
            block_size (int): Alignment block size.
            num_experts (int): Number of experts.
            max_loras (int): Number of LoRA slots to allocate buffers for.
            adapter_enabled (torch.Tensor): Forwarded to the op.
            expert_map (torch.Tensor | None): If given, the returned expert
                ids are remapped through ``expert_map[expert_ids]``.
            pad_sorted_ids (bool): Round the padded token count up to a
                multiple of ``block_size``. Defaults to False.

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            ``(sorted_ids, expert_ids, num_tokens_post_pad)`` of lengths
            ``max_loras * max_num_tokens_padded``,
            ``max_loras * max_num_m_blocks`` and ``max_loras``, all int32.
        """
        # Each expert can add at most (block_size - 1) padding slots.
        max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
        if pad_sorted_ids:
            max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
        sorted_ids = torch.empty(
            (max_loras * max_num_tokens_padded,),
            dtype=torch.int32,
            device=topk_ids.device,
        )
        max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
        # Expert ids must be set default to -1 to prevent a blank block
        # NOTE(review): allocated with torch.empty (uninitialized), so the -1
        # default presumably comes from the op itself -- confirm. If -1
        # entries survive, `expert_map[expert_ids]` below would read
        # expert_map's last element (negative indexing).
        expert_ids = torch.empty(
            (max_loras * max_num_m_blocks,),
            dtype=torch.int32,
            device=topk_ids.device,
        )
        num_tokens_post_pad = torch.empty(
            (max_loras,), dtype=torch.int32, device=topk_ids.device
        )

        (token_lora_mapping, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(
            num_tokens
        )

        ops.moe_lora_align_block_size(
            topk_ids,
            token_lora_mapping,
            num_experts,
            block_size,
            max_loras,
            max_num_tokens_padded,
            max_num_m_blocks,
            sorted_ids,
            expert_ids,
            num_tokens_post_pad,
            adapter_enabled,
            lora_ids,
        )
        if expert_map is not None:
            expert_ids = expert_map[expert_ids]

        return sorted_ids, expert_ids, num_tokens_post_pad

    def add_lora_fused_moe(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_a_stacked: list[torch.Tensor],
        lora_b_stacked: list[torch.Tensor],
        topk_weights: torch.Tensor,
        sorted_token_ids: torch.Tensor,
        expert_ids: torch.Tensor,
        num_tokens_post_padded: torch.Tensor,
        max_lora_rank: int,
        top_k_num: int,
        config,
        adapter_enabled: torch.Tensor,
        mul_routed_weight: bool = False,
    ) -> None:
        """
        Performs a fused forward computation for LoRA of a
        Mixture-of-Experts (MoE) layer.

        All computation is delegated to ``fused_moe_lora``; this method only
        supplies the LoRA ids from the per-token metadata and unpacks the
        kernel tiling configuration.

        Args:
            y (torch.Tensor): Output tensor, forwarded to the kernel.
            x (torch.Tensor): Input tensor; ``x.size(0)`` is used as the
                token count for the metadata lookup.
            lora_a_stacked (list[torch.Tensor]): lora_a's weights.
            lora_b_stacked (list[torch.Tensor]): lora_b's weights.
            topk_weights (torch.Tensor): Forwarded to the kernel.
            sorted_token_ids (torch.Tensor): Forwarded to the kernel;
                presumably the first output of ``moe_lora_align_block_size``.
            expert_ids (torch.Tensor): Forwarded to the kernel.
            num_tokens_post_padded (torch.Tensor): Forwarded to the kernel.
            max_lora_rank (int): Forwarded to the kernel.
            top_k_num (int): Forwarded to the kernel.
            config: Mapping that must provide ``BLOCK_SIZE_M``,
                ``BLOCK_SIZE_N``, ``BLOCK_SIZE_K`` and ``GROUP_SIZE_M``;
                ``SPLIT_K`` is optional and defaults to 1.
            adapter_enabled (torch.Tensor): Forwarded to the kernel.
            mul_routed_weight (bool): Forwarded to the kernel. Defaults to
                False.
        """
        (_, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(x.size(0))
        fused_moe_lora(
            y,
            x,
            lora_a_stacked,
            lora_b_stacked,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            max_lora_rank,
            top_k_num,
            lora_ids,
            adapter_enabled,
            config["BLOCK_SIZE_M"],
            config["BLOCK_SIZE_N"],
            config["BLOCK_SIZE_K"],
            config["GROUP_SIZE_M"],
            config.get("SPLIT_K", 1),
            mul_routed_weight,
        )