"requirements/test/cuda.txt" did not exist on "d6c86d09aecb910fd336ba83ede70265ee81149a"
punica_gpu.py 22.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""

10
from typing import final
11
12
13

import torch

14
from vllm.lora.layers import LoRAMapping
15
from vllm.lora.utils import get_captured_lora_counts
16
from vllm.triton_utils import HAS_TRITON, triton
Cyrus Leung's avatar
Cyrus Leung committed
17
from vllm.utils.math_utils import round_up
18
19

if HAS_TRITON:
20
21
22
23
24
25
26
27
    from vllm.lora.ops.triton_ops import (
        LoRAKernelMeta,
        fused_moe_lora,
        lora_expand,
        lora_shrink,
    )

from vllm import _custom_ops as ops
28
29
30

from .punica_base import PunicaWrapperBase

31

32
@final
33
class PunicaWrapperGPU(PunicaWrapperBase):
34
    """
    PunicaWrapperGPU is designed to manage and provide metadata for the punica
    kernel. The main function is to maintain the state information for
    Multi-LoRA, and to provide the interface for the punica triton kernel.
    """

40
41
42
43
    def __init__(
        self,
        max_num_batched_tokens: int,
        max_batches: int,
        device: torch.device | str,
        **kwargs,
    ):
        """
        Initialize base state and build the two kernel-metadata objects.

        Args:
            max_num_batched_tokens (int): Token capacity used to size both
                metadata objects.
            max_batches (int): Forwarded to ``PunicaWrapperBase.__init__``.
            device (torch.device | str): Device for the metadata tensors.
            **kwargs: Must contain ``lora_config``; its ``max_loras`` and
                ``specialize_active_lora`` attributes are read here.
        """
        PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device)

        cfg = kwargs["lora_config"]
        self.lora_config = cfg
        self.max_loras = cfg.max_loras

        # LoRA counts used for cudagraph specialization.
        lora_counts = get_captured_lora_counts(
            self.max_loras, cfg.specialize_active_lora
        )

        def _new_meta():
            # Both metadata objects share the same sizing and device.
            return LoRAKernelMeta.make(
                self.max_loras,
                max_num_batched_tokens,
                device=device,
                captured_lora_counts=lora_counts,
            )

        # Per-token metadata, consumed by the shrink/expand/embedding/MoE paths.
        self.token_mapping_meta = _new_meta()

        # Per-sample metadata, consumed by the logits path. With speculative
        # decoding the sample count is
        # max_batches * (num_speculative_decoding_tokens + 1); sizing this by
        # max_num_batched_tokens could be tightened to that value.
        self.prompt_mapping_meta = _new_meta()
    def update_metadata(
        self,
        mapping: LoRAMapping,
        lora_index_to_id: list[int | None],
        max_loras: int,
        vocab_size: int,
        **kwargs,
    ):
        """
        Refresh the base mapping state, then rebuild the kernel metadata.

        Args:
            mapping (LoRAMapping): Current mapping; its ``is_prefill`` flag is
                cached on the wrapper.
            lora_index_to_id (list[int | None]): Forwarded to
                ``_update_base_metadata``.
            max_loras (int): Forwarded to ``_update_base_metadata``.
            vocab_size (int): Forwarded to ``_update_base_metadata``.
        """
        self.is_prefill = mapping.is_prefill
        self._update_base_metadata(mapping, lora_index_to_id, max_loras, vocab_size)

        # Prepare cuda kernel metadata tensors: token-level indices feed the
        # token metadata, sampler indices feed the prompt/logits metadata.
        token_meta = self.token_mapping_meta
        token_meta.prepare_tensors(self.token_lora_indices)
        prompt_meta = self.prompt_mapping_meta
        prompt_meta.prepare_tensors(self.sampler_indices)
    def add_shrink(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_a_stacked: tuple[torch.Tensor, ...],
        scale: float,
        **kwargs,
    ):
        """
        Performs GEMM for multiple slices of lora_a.

        Semantics:
        for i in range(len(lora_a_stacked)):
            y[i] += (x @ lora_a_stacked[i]) * scale

        Args:
            y (torch.Tensor): Output tensors, one per slice.
            x (torch.Tensor): Input tensor; all leading dimensions are
                flattened into a single row dimension before the kernel call.
            lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
            scale (float): Scaling factor for the operation
        """
        flat_x = x.view(-1, x.shape[-1])
        meta = self.token_mapping_meta.meta_args(
            flat_x.size(0), self.lora_config.specialize_active_lora
        )
        lora_shrink(flat_x, lora_a_stacked, y, *meta, scale)
    def add_expand(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_b_stacked: tuple[torch.Tensor, ...],
        output_slices: tuple[int, ...],
        offset_start: int = 0,
        add_inputs=True,
        **kwargs,
    ) -> None:
        """
        Performs GEMM for multiple slices of lora_b.

        Semantics:
            offset = offset_start
            for i in range(len(lora_b_stacked)):
                slice = output_slices[i]
                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
                offset += slice

        Args:
            y (torch.Tensor): Output tensor. Its leading dimensions are
                flattened with ``view`` (which shares storage), so the kernel
                writes into ``y`` in place.
            x (torch.Tensor): Input tensors, 3-D with the slice index as
                dim 0 and the token index as dim 1; ``x.size(0)`` must equal
                ``len(output_slices)``.
            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
            output_slices (tuple[int, ...]): Every slice's size
            offset_start (int): Column of ``y`` where the first slice is
                written. Defaults to 0.
            add_inputs (bool): If True, add LoRA output to y; if False, write
                LoRA-only output to y (used for dual-stream when base and LoRA
                run on different CUDA streams). Defaults to True.
            **kwargs: Accepted for interface compatibility; unused.

        Raises:
            AssertionError: If ``x`` is not 3-D or its slice count does not
                match ``output_slices``.
        """
        # view() shares storage with y, so no copy-back is needed afterwards.
        y_2d = y.view(-1, y.shape[-1])

        assert x.ndim == 3
        assert x.size(0) == len(output_slices)
        num_tokens = x.size(1)  # first dimension is the num slices

        lora_expand(
            x,
            lora_b_stacked,
            y_2d,
            *self.token_mapping_meta.meta_args(
                num_tokens, self.lora_config.specialize_active_lora
            ),
            offset_start=offset_start,
            add_inputs=add_inputs,
        )
    def add_lora_embedding(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_b_stacked: torch.Tensor,
        add_inputs: bool = True,
        **kwargs,
    ) -> None:
        """
        Applies lora specifically for VocabParallelEmbeddingWithLoRA.

        Semantics:
            y += x @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor; passed to the expand kernel as a
                single slice (a new leading dim of size 1).
            lora_b_stacked (torch.Tensor): lora_b's weights.
            add_inputs (bool): Forwarded to the expand kernel. Defaults to True.
        """
        one_slice_x = x.unsqueeze(dim=0)
        weights = (lora_b_stacked,)
        meta = self.token_mapping_meta.meta_args(
            x.size(0), self.lora_config.specialize_active_lora
        )
        lora_expand(
            one_slice_x,
            weights,
            y,
            *meta,
            offset_start=0,
            add_inputs=add_inputs,
        )
    def add_lora_linear(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_a_stacked: tuple[torch.Tensor, ...],
        lora_b_stacked: tuple[torch.Tensor, ...],
        scale: float,
        output_slices: tuple[int, ...],
        *,
        buffer: torch.Tensor | None = None,
        **kwargs,
    ) -> None:
        """
        Applicable to linear-related lora.

        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (
                    x[i].unsqueeze(0)
                    @ lora_a_stacked[indices[i], layer_idx, :, :]
                    @ lora_b_stacked[indices[i], layer_idx, :, :]
                    * scale
                    ).squeeze(0)

        Args:
            y (torch.Tensor): Output tensor. Will be changed in-place.
            x (torch.Tensor): Input tensor. Leading dimensions are flattened
                into one row dimension (as ``add_shrink`` does), so the
                intermediate buffer is sized by the real row count.
            lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
            scale (float): Scaling factor.
            output_slices (tuple[int, ...]): Every slice's size.
            buffer (Optional[torch.Tensor]): Must be None; the float32
                intermediate buffer is allocated here.
            **kwargs: ``add_inputs`` (default True) is popped and forwarded to
                ``add_expand``; remaining kwargs go to both ``add_shrink`` and
                ``add_expand``.

        Raises:
            AssertionError: If the per-slice sequences differ in length or a
                ``buffer`` is passed in.
        """
        assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
        assert buffer is None, (
            "To minimize overhead, the buffer should be created by "
            ".add_lora_linear() instead of being passed in."
        )

        # add_shrink flattens x the same way; doing it here first keeps the
        # buffer's row dim equal to the rows the shrink kernel processes.
        # This is a no-op for 2-D x.
        x = x.view(-1, x.shape[-1])
        r = lora_b_stacked[0].size(-1)
        # We set the buffer to be float32 by default, refer to:
        # https://github.com/triton-lang/triton/issues/1387
        # Note: buffer is zeroed inside the shrink op
        buffer = torch.empty(
            (len(output_slices), x.size(0), r), dtype=torch.float32, device=x.device
        )
        add_inputs = kwargs.pop("add_inputs", True)

        self.add_shrink(
            buffer,  # type: ignore
            x,
            lora_a_stacked,
            scale,
            **kwargs,
        )
        self.add_expand(
            y,
            buffer,  # type: ignore
            lora_b_stacked,
            output_slices,
            add_inputs=add_inputs,
            **kwargs,
        )
    def add_lora_logits(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_a_stacked: torch.Tensor,
        lora_b_stacked: torch.Tensor,
        scale,
        *,
274
        buffer: torch.Tensor | None = None,
275
276
        **kwargs,
    ) -> None:
277
278
        """
        Applies lora  specifically for LogitsProcessorWithLoRA.
279

280
281
282
283
284
285
286
287
        Semantics:
            buffer = (x @ lora_a_stacked) * scale
            y += buffer @ lora_b_stacked

        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_a_stacked (torch.Tensor): lora_a's weights.
Jee Jee Li's avatar
Jee Jee Li committed
288
            lora_b_stacked (torch.Tensor): lora_b's weights.
289
            scale (float): Scaling factor.
Jee Jee Li's avatar
Jee Jee Li committed
290
            buffer (Optional[torch.Tensor]): Default to None.
291
292
293
294
295
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])
        r = lora_b_stacked.size(-1)
296
297
298
299
300
301
302
303
304

        assert buffer is None, (
            "To minimize overhead, the buffer should be created by "
            ".add_lora_linear() instead of being passed in."
        )
        # We set the buffer to be float32 by default, refer to:
        # https://github.com/triton-lang/triton/issues/1387
        # Note: buffer is zeroed inside the shrink op
        buffer = torch.empty((x.size(0), r), dtype=torch.float32, device=x.device)
305

306
307
308
309
        lora_shrink(
            x,
            [lora_a_stacked],
            buffer.unsqueeze(dim=0),
310
311
312
            *self.prompt_mapping_meta.meta_args(
                x.size(0), self.lora_config.specialize_active_lora
            ),
313
314
            scale,
        )
315

316
317
318
319
        lora_expand(
            buffer.unsqueeze(dim=0),
            [lora_b_stacked],
            y,
320
321
322
            *self.prompt_mapping_meta.meta_args(
                buffer.size(0), self.lora_config.specialize_active_lora
            ),
323
324
            add_inputs=True,
        )
325
        y = y.view_as(y_org)
326
327
328
329
330
331
332
333

    def moe_lora_align_block_size(
        self,
        topk_ids: torch.Tensor,
        num_tokens: int,
        block_size: int,
        num_experts: int,
        max_loras: int,
334
        adapter_enabled: torch.Tensor,
335
336
        expert_map: torch.Tensor | None = None,
        pad_sorted_ids: bool = False,
337
338
        naive_block_assignment: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
339
340
341
342
        """
        Aligns tokens and experts into block-sized chunks for LoRA-based
        mixture-of-experts (MoE) execution.
        """
343
344
345
346
        (token_lora_mapping, _, _, _, lora_ids, _, _) = (
            self.token_mapping_meta.meta_args(
                num_tokens, self.lora_config.specialize_active_lora
            )
347
        )
348
349
350
351
352
353
354
355
        if naive_block_assignment:
            expert_ids = topk_ids.reshape(-1)
            sorted_ids = None
            num_tokens_post_pad = None
        else:
            max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
            if pad_sorted_ids:
                max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
356
357
            if topk_ids.numel() < num_experts:
                max_num_tokens_padded = topk_ids.numel() * block_size
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
            sorted_ids = torch.empty(
                (max_loras * max_num_tokens_padded,),
                dtype=torch.int32,
                device=topk_ids.device,
            )
            max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
            # Expert ids must be set default to -1 to prevent a blank block
            expert_ids = torch.empty(
                (max_loras * max_num_m_blocks,),
                dtype=torch.int32,
                device=topk_ids.device,
            )
            num_tokens_post_pad = torch.empty(
                (max_loras), dtype=torch.int32, device=topk_ids.device
            )

            ops.moe_lora_align_block_size(
                topk_ids,
                token_lora_mapping,
                num_experts,
                block_size,
                max_loras,
                max_num_tokens_padded,
                max_num_m_blocks,
                sorted_ids,
                expert_ids,
                num_tokens_post_pad,
                adapter_enabled,
                lora_ids,
            )
            if expert_map is not None:
                expert_ids = expert_map[expert_ids]

        return None, sorted_ids, expert_ids, num_tokens_post_pad
392
393
394
395
396

    def add_lora_fused_moe(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_a_stacked: tuple[torch.Tensor, ...],
        lora_b_stacked: tuple[torch.Tensor, ...],
        topk_weights: torch.Tensor,
        sorted_token_ids: torch.Tensor | None,
        expert_ids: torch.Tensor,
        num_tokens_post_padded: torch.Tensor | None,
        max_lora_rank: int,
        top_k_num: int,
        shrink_config,
        expand_config,
        adapter_enabled: torch.Tensor,
        mul_routed_weight=False,
        fully_sharded: bool = False,
        offset: int = 0,
        token_lora_mapping: torch.Tensor | None = None,
    ):
        """
        Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer.

        Tile parameters for each stage are read from ``shrink_config`` /
        ``expand_config`` with ``.get``; missing keys fall back to
        BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, BLOCK_SIZE_K=32, GROUP_SIZE_M=8,
        NUM_WARPS=4, NUM_STAGES=3, SPLIT_K=1.

        Args:
            token_lora_mapping (torch.Tensor | None): Per-token LoRA mapping;
                when None, the mapping from ``token_mapping_meta`` is used.
            Other arguments are forwarded to ``fused_moe_lora`` unchanged.
        """
        (
            meta_mapping,
            _,
            _,
            _,
            lora_ids,
            _,
            num_active_loras,
        ) = self.token_mapping_meta.meta_args(
            x.size(0), self.lora_config.specialize_active_lora
        )
        mapping = meta_mapping if token_lora_mapping is None else token_lora_mapping

        # (key, default) pairs in the positional order fused_moe_lora expects
        # for each of the shrink and expand stages.
        tile_defaults = (
            ("BLOCK_SIZE_M", 64),
            ("BLOCK_SIZE_N", 64),
            ("BLOCK_SIZE_K", 32),
            ("GROUP_SIZE_M", 8),
            ("NUM_WARPS", 4),
            ("NUM_STAGES", 3),
            ("SPLIT_K", 1),
        )
        shrink_tiles = [shrink_config.get(key, dflt) for key, dflt in tile_defaults]
        expand_tiles = [expand_config.get(key, dflt) for key, dflt in tile_defaults]

        fused_moe_lora(
            y,
            x,
            lora_a_stacked,
            lora_b_stacked,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            mapping,
            max_lora_rank,
            top_k_num,
            lora_ids,
            num_active_loras,
            adapter_enabled,
            *shrink_tiles,
            *expand_tiles,
            mul_routed_weight,
            fully_sharded,
            offset,
        )

    def add_lora_w13(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_a_stacked: tuple[torch.Tensor, ...],
        lora_b_stacked: tuple[torch.Tensor, ...],
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor | None,
        w1: torch.Tensor,
        w2: torch.Tensor,
        num_tokens: int,
        top_k_num: int,
        max_loras: int,
        adapter_enabled: torch.Tensor,
        local_num_experts: int,
        top_k: int,
        num_slices: int,
        fully_sharded: bool,
        use_tuned_config: bool,
    ) -> tuple[
        torch.Tensor | None,
        torch.Tensor | None,
        torch.Tensor | None,
        torch.Tensor | None,
    ]:
        """
        Apply the LoRA delta for the w13 projection of a fused MoE layer.

        Resolves shrink/expand kernel configs from the tuned table
        (``get_lora_op_configs``) when ``use_tuned_config``, otherwise via
        ``try_get_optimal_moe_lora_config``. It then aligns routed tokens per
        LoRA slot and runs ``add_lora_fused_moe`` into ``y`` viewed as
        ``(-1, top_k_num, y.shape[-1])``.

        Returns:
            ``(sorted_token_ids_lora, expert_ids_lora,
            num_tokens_post_padded_lora, token_lora_mapping)`` as produced by
            ``moe_lora_align_block_size`` (flat, not reshaped). These match
            ``add_lora_w2``'s routing parameters, presumably so the same
            routing can be reused there -- confirm at the call site.
        """
        import functools

        from vllm.lora.layers.utils import try_get_optimal_moe_lora_config
        from vllm.lora.ops.triton_ops.utils import (
            _normalize_lora_config_keys,
            get_lora_op_configs,
        )
        from vllm.model_executor.layers.fused_moe.config import _get_config_dtype_str

        dtype_str = _get_config_dtype_str(
            dtype=x.dtype,
            use_fp8_w8a8=False,
            use_int8_w8a16=False,
            use_int4_w4a16=False,
        )
        rank = lora_a_stacked[0].shape[-2]

        if use_tuned_config:
            lookup = functools.partial(
                get_lora_op_configs,
                max_loras=max_loras,
                batch=num_tokens,
                hidden_size=x.shape[-1],
                rank=rank,
                num_slices=num_slices,
                moe_intermediate_size=lora_b_stacked[0].shape[-2],
            )
        else:
            lookup = functools.partial(
                try_get_optimal_moe_lora_config,
                w1_shape=w1.shape,
                w2_shape=w2.shape,
                rank=rank,
                top_k=top_k,
                dtype=dtype_str,
                M=num_tokens,
            )
        raw_shrink = lookup(op_type="fused_moe_lora_w13_shrink")
        raw_expand = lookup(op_type="fused_moe_lora_w13_expand")
        shrink_config = _normalize_lora_config_keys(raw_shrink)
        expand_config = _normalize_lora_config_keys(raw_expand)

        # Naive per-entry block assignment is chosen only when there is no
        # expert map and the routed workload is small relative to
        # local_num_experts * max_loras.
        sparsity_factor = 8
        use_naive = (
            expert_map is None
            and num_tokens * top_k * sparsity_factor <= local_num_experts * max_loras
        )

        # NOTE(review): alignment uses the shrink config's BLOCK_SIZE_M; confirm
        # the expand kernel (and any later reuse of these ids) is compatible.
        align_block = int(shrink_config.get("BLOCK_SIZE_M") or 64)
        (
            token_lora_mapping,
            sorted_token_ids_lora,
            expert_ids_lora,
            num_tokens_post_padded_lora,
        ) = self.moe_lora_align_block_size(
            topk_ids,
            num_tokens,
            align_block,
            local_num_experts,
            max_loras,
            adapter_enabled,
            expert_map,
            naive_block_assignment=use_naive,
        )

        # Reshape to one row per LoRA slot when sorted ids exist; otherwise the
        # expert ids are passed through unchanged.
        if sorted_token_ids_lora is None:
            sorted_2d, expert_ids_2d = None, expert_ids_lora
        else:
            expert_ids_2d = expert_ids_lora.view(max_loras, -1)
            sorted_2d = sorted_token_ids_lora.view(max_loras, -1)

        self.add_lora_fused_moe(
            y.view(-1, top_k_num, y.shape[-1]),
            x,
            lora_a_stacked,
            lora_b_stacked,
            topk_weights,
            sorted_2d,
            expert_ids_2d,
            num_tokens_post_padded_lora,
            rank,
            top_k,
            shrink_config,
            expand_config,
            adapter_enabled,
            fully_sharded=fully_sharded,
            token_lora_mapping=token_lora_mapping,
        )

        return (
            sorted_token_ids_lora,
            expert_ids_lora,
            num_tokens_post_padded_lora,
            token_lora_mapping,
        )

    def add_lora_w2(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_a_stacked: tuple[torch.Tensor, ...],
        lora_b_stacked: tuple[torch.Tensor, ...],
        topk_weights: torch.Tensor,
        sorted_token_ids_lora: torch.Tensor | None,
        expert_ids_lora: torch.Tensor | None,
        num_tokens_post_padded_lora: torch.Tensor | None,
        token_lora_mapping: torch.Tensor | None,
        num_tokens: int,
        w1: torch.Tensor,
        w2: torch.Tensor,
        top_k_num: int,
        max_loras: int,
        adapter_enabled: torch.Tensor,
        top_k: int,
        fully_sharded: bool,
        tp_rank: int,
        use_tuned_config: bool,
    ) -> None:
        """
        Apply the LoRA delta for the w2 projection of a fused MoE layer.

        Resolves shrink/expand kernel configs from the tuned table
        (``get_lora_op_configs``) when ``use_tuned_config``, otherwise via
        ``try_get_optimal_moe_lora_config``. It then runs
        ``add_lora_fused_moe`` with ``mul_routed_weight=True``, using the
        routing buffers passed in. These are presumably the values returned by
        ``add_lora_w13``; confirm at the call site.

        Notes:
            ``top_k_num`` is accepted but not used by this method.
            When ``fully_sharded``, the kernel output offset is
            ``lora_b_stacked[0].shape[-2] * tp_rank``; otherwise it is 0.

        Raises:
            AssertionError: If sorted ids are given without expert ids.
        """
        import functools

        from vllm.lora.layers.utils import try_get_optimal_moe_lora_config
        from vllm.lora.ops.triton_ops.utils import (
            _normalize_lora_config_keys,
            get_lora_op_configs,
        )
        from vllm.model_executor.layers.fused_moe.config import _get_config_dtype_str

        dtype_str = _get_config_dtype_str(
            dtype=x.dtype,
            use_fp8_w8a8=False,
            use_int8_w8a16=False,
            use_int4_w4a16=False,
        )
        rank = lora_a_stacked[0].shape[-2]

        if use_tuned_config:
            lookup = functools.partial(
                get_lora_op_configs,
                max_loras=max_loras,
                batch=num_tokens,
                hidden_size=y.shape[-1],
                rank=rank,
                num_slices=1,
                moe_intermediate_size=lora_a_stacked[0].shape[-1],
            )
        else:
            lookup = functools.partial(
                try_get_optimal_moe_lora_config,
                w1_shape=w1.shape,
                w2_shape=w2.shape,
                rank=rank,
                top_k=top_k,
                dtype=dtype_str,
                M=num_tokens,
            )
        raw_shrink = lookup(op_type="fused_moe_lora_w2_shrink")
        raw_expand = lookup(op_type="fused_moe_lora_w2_expand")
        shrink_config = _normalize_lora_config_keys(raw_shrink)
        expand_config = _normalize_lora_config_keys(raw_expand)

        # Reshape to one row per LoRA slot when sorted ids exist; otherwise the
        # expert ids are passed through unchanged.
        sorted_2d, expert_ids_2d = sorted_token_ids_lora, expert_ids_lora
        if sorted_2d is not None:
            assert expert_ids_2d is not None
            expert_ids_2d = expert_ids_2d.view(max_loras, -1)
            sorted_2d = sorted_2d.view(max_loras, -1)

        # w2_lora_b shape[-2] is hidden_size // tp_size when fully_sharded
        shard_rows = lora_b_stacked[0].shape[-2]
        out_offset = shard_rows * tp_rank if fully_sharded else 0

        self.add_lora_fused_moe(
            y,
            x,
            lora_a_stacked,
            lora_b_stacked,
            topk_weights,
            sorted_2d,
            expert_ids_2d,
            num_tokens_post_padded_lora,
            rank,
            top_k,
            shrink_config,
            expand_config,
            adapter_enabled,
            mul_routed_weight=True,
            fully_sharded=fully_sharded,
            offset=out_offset,
            token_lora_mapping=token_lora_mapping,
        )