# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm import envs
from vllm.config.lora import LoRAConfig
from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from vllm.distributed.utils import divide
from vllm.lora.layers.base import BaseLayerWithLoRA
from vllm.lora.ops.triton_ops.utils import get_lora_op_configs
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import (
    _get_config_dtype_str,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
    modular_marlin_fused_moe,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
    modular_triton_fused_moe,
    try_get_optimal_moe_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
    FusedMoEModularMethod,
)


class FusedMoEWithLoRA(BaseLayerWithLoRA):
    def __init__(self, base_layer: FusedMoE) -> None:
        """Wrap ``base_layer`` and hook LoRA computation into its fused-MoE path.

        Args:
            base_layer: The fused MoE layer to wrap. It must not use expert
                parallelism (``use_ep`` must be falsy).

        Raises:
            AssertionError: If ``base_layer.use_ep`` is truthy.
        """
        super().__init__()
        self.base_layer = base_layer

        assert not self.base_layer.use_ep, (
            "EP support for Fused MoE LoRA is not implemented yet."
        )
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        # LoRA buffers are later allocated on the device of the base weights.
        self.device = base_layer.w2_weight.device
        # Number of separately stacked w13 LoRA tensors (indexed 0 and 1 by
        # create_lora_weights / set_lora of this class).
        self._w13_slices = 2
        self._inject_lora_into_fused_moe()

    def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]:
        normalized_config = {}
        for key, value in config.items():
            if key.islower():
                if key.startswith("block_"):
                    normalized_key = "BLOCK_SIZE_" + key.split("_")[-1].upper()
                else:
                    normalized_key = key.upper()
            else:
                normalized_key = key
            normalized_config[normalized_key] = value
        return normalized_config

    def _get_lora_moe_configs(
        self,
        op_prefix: str,
        num_loras: int,
        rank: int,
        num_slices: int,
        M: int,
        layer: FusedMoE,
        top_k: int,
        config_dtype: str,
    ) -> tuple[dict[str, int | None], dict[str, int | None]]:
        """Select kernel configs for the LoRA shrink and expand ops.

        When ``envs.VLLM_TUNED_CONFIG_FOLDER`` is set, the configs are looked
        up with ``get_lora_op_configs`` under the op types
        ``fused_moe_lora_{op_prefix}_shrink`` / ``..._expand``. Otherwise both
        configs come from ``try_get_optimal_moe_config`` for the base layer's
        weight shapes.

        Args:
            op_prefix: Projection name used to build the op type (callers in
                this file pass ``"w13"`` or ``"w2"``). Tuned lookup only.
            num_loras: Forwarded as ``max_loras``. Tuned lookup only.
            rank: LoRA rank. Tuned lookup only.
            num_slices: Number of stacked LoRA slices. Tuned lookup only.
            M: Batch size to pick the config for. Used by both paths.
            layer: MoE layer whose sizes and weights describe the problem.
            top_k: Experts per token. Fallback path only.
            config_dtype: Dtype string for the fallback lookup.

        Returns:
            ``(shrink_config, expand_config)`` with keys normalized by
            ``_normalize_keys``.
        """
        if envs.VLLM_TUNED_CONFIG_FOLDER:
            hidden_size = layer.hidden_size
            intermediate_size = layer.intermediate_size_per_partition
            shrink_config = get_lora_op_configs(
                op_type=f"fused_moe_lora_{op_prefix}_shrink",
                max_loras=num_loras,
                batch=M,
                hidden_size=hidden_size,
                rank=rank,
                num_slices=num_slices,
                moe_intermediate_size=intermediate_size,
            )
            expand_config = get_lora_op_configs(
                op_type=f"fused_moe_lora_{op_prefix}_expand",
                max_loras=num_loras,
                batch=M,
                hidden_size=hidden_size,  # lora_a_stacked.shape[-1],
                rank=rank,
                num_slices=num_slices,
                moe_intermediate_size=intermediate_size,  # lora_b_stacked.shape[-2],
            )
        else:  # fall back to the default config
            get_config_func = functools.partial(
                try_get_optimal_moe_config,
                layer.w13_weight.size(),
                layer.w2_weight.size(),
                top_k,
                config_dtype,
                block_shape=layer.quant_method.moe_quant_config.block_shape,
            )
            shrink_config = get_config_func(M)
            expand_config = get_config_func(M)
        # Both sources are mapped onto the same upper-case key convention
        # (callers read e.g. "BLOCK_SIZE_M").
        shrink_config = self._normalize_keys(shrink_config)
        expand_config = self._normalize_keys(expand_config)
        return shrink_config, expand_config

    def _inject_lora_into_fused_moe(self):
        """Install LoRA hooks on the base layer's fused-MoE execution path.

        Builds a modular fused-MoE function for the base layer (Marlin when
        the quant config reports ``use_mxfp4_w4a16``, Triton otherwise), wraps
        three of its callables, and installs the result as
        ``base_layer.quant_method`` through ``FusedMoEModularMethod``:

        * ``forward``: records the routing inputs of the current call.
        * ``fused_experts.activation``: runs the w13 LoRA op before the
          original activation.
        * ``fused_experts.moe_sum``: runs the w2 LoRA op before the original
          reduction.

        The wrappers exchange tensors through a dict captured by the closures,
        so they rely on being invoked in that order within one forward call.
        They also read ``self.punica_wrapper``, ``self.max_loras``,
        ``self.adapter_enabled``, ``self.fully_sharded`` and the stacked LoRA
        buffers at call time; none of those are set by this method.
        """
        # Scratch space shared by the three wrappers during one forward call.
        moe_state_dict = {}
        top_k = self.base_layer.top_k

        self.base_layer.ensure_moe_quant_config_init()
        quant_config = self.base_layer.quant_method.moe_quant_config

        m_fused_moe_fn = (
            modular_triton_fused_moe(
                quant_config, shared_experts=self.base_layer.shared_experts
            )
            if not quant_config.use_mxfp4_w4a16
            else modular_marlin_fused_moe(
                quant_config, shared_experts=self.base_layer.shared_experts
            )
        )

        def fwd_decorator(layer, func):
            def wrapper(*args, **kwargs):
                # These inputs must be passed by keyword; a positional call
                # raises KeyError here.
                moe_state_dict["hidden_states"] = kwargs["hidden_states"]
                moe_state_dict["topk_ids"] = kwargs["topk_ids"]
                moe_state_dict["topk_weights"] = kwargs["topk_weights"]
                moe_state_dict["expert_map"] = kwargs["expert_map"]
                moe_state_dict["apply_router_weight_on_input"] = kwargs[
                    "apply_router_weight_on_input"
                ]
                return func(*args, **kwargs)

            return wrapper

        def act_decorator(layer, func):
            def wrapper(*args, **kwargs):
                # Exactly three positional arguments are expected; the first
                # one is not used by the LoRA path.
                _, output, act_input = args

                hidden_states = moe_state_dict["hidden_states"]
                topk_weights = moe_state_dict["topk_weights"]
                curr_topk_ids = moe_state_dict["topk_ids"]
                expert_map = moe_state_dict["expert_map"]

                config_dtype = _get_config_dtype_str(
                    dtype=hidden_states.dtype,
                    use_fp8_w8a8=False,
                    use_int8_w8a16=False,
                    use_int4_w4a16=False,
                )
                CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
                num_tokens = hidden_states.size(0)
                M = min(num_tokens, CHUNK_SIZE)
                max_lora_rank = self.w13_lora_a_stacked[0].shape[-2]
                shrink_config, expand_config = self._get_lora_moe_configs(
                    op_prefix="w13",
                    num_loras=self.max_loras,
                    rank=max_lora_rank,
                    num_slices=self._w13_slices,
                    M=M,
                    layer=layer,
                    top_k=top_k,
                    config_dtype=config_dtype,
                )

                # get the block size of m from customized config or default config
                (
                    sorted_token_ids_lora,
                    expert_ids_lora,
                    num_tokens_post_padded_lora,
                ) = self.punica_wrapper.moe_lora_align_block_size(
                    curr_topk_ids,
                    num_tokens,
                    shrink_config["BLOCK_SIZE_M"],
                    self.base_layer.local_num_experts,
                    self.max_loras,
                    self.adapter_enabled,
                    expert_map,
                )

                # Stash the alignment results (before reshaping) so that the
                # moe_sum wrapper can reuse them for the w2 LoRA op.
                moe_state_dict["sorted_token_ids_lora"] = sorted_token_ids_lora
                moe_state_dict["expert_ids_lora"] = expert_ids_lora
                moe_state_dict["num_tokens_post_padded_lora"] = (
                    num_tokens_post_padded_lora
                )

                # One row per LoRA slot.
                expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
                sorted_token_ids_lora = sorted_token_ids_lora.view(self.max_loras, -1)

                self.punica_wrapper.add_lora_fused_moe(
                    act_input.view(-1, top_k, act_input.shape[-1]),
                    hidden_states,
                    self.w13_lora_a_stacked,
                    self.w13_lora_b_stacked,
                    topk_weights,
                    sorted_token_ids_lora,
                    expert_ids_lora,
                    num_tokens_post_padded_lora,
                    max_lora_rank,
                    top_k,
                    shrink_config,  ## pass the shrink config
                    expand_config,  ## pass the expand config
                    self.adapter_enabled,
                    fully_sharded=self.fully_sharded,
                )

                result = func(*args, **kwargs)

                # The activation output is the input of the w2 LoRA op.
                moe_state_dict["intermediate_cache2"] = output
                return result

            return wrapper

        def moe_sum_decorator(layer, func):
            def wrapper(*args, **kwargs):
                hidden_states = moe_state_dict["hidden_states"]
                topk_weights = moe_state_dict["topk_weights"]

                config_dtype = _get_config_dtype_str(
                    dtype=hidden_states.dtype,
                    use_fp8_w8a8=False,
                    use_int8_w8a16=False,
                    use_int4_w4a16=False,
                )
                CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
                num_tokens = hidden_states.size(0)
                M = min(num_tokens, CHUNK_SIZE)
                max_lora_rank = self.w2_lora_a_stacked[0].shape[-2]
                shrink_config, expand_config = self._get_lora_moe_configs(
                    op_prefix="w2",
                    num_loras=self.max_loras,
                    rank=max_lora_rank,
                    num_slices=1,
                    M=M,
                    layer=layer,
                    top_k=top_k,
                    config_dtype=config_dtype,
                )

                # Reuse the token/expert alignment computed by the activation
                # wrapper earlier in the same forward call.
                sorted_token_ids_lora = moe_state_dict["sorted_token_ids_lora"]
                expert_ids_lora = moe_state_dict["expert_ids_lora"]
                num_tokens_post_padded_lora = moe_state_dict[
                    "num_tokens_post_padded_lora"
                ]

                expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
                sorted_token_ids_lora = sorted_token_ids_lora.view(self.max_loras, -1)
                intermediate_cache2 = moe_state_dict["intermediate_cache2"]
                intermediate_cache3 = args[0]

                shard_size_w2 = divide(self.base_layer.hidden_size, self.tp_size)

                self.punica_wrapper.add_lora_fused_moe(
                    intermediate_cache3,
                    intermediate_cache2,
                    self.w2_lora_a_stacked,
                    self.w2_lora_b_stacked,
                    topk_weights,
                    sorted_token_ids_lora,
                    expert_ids_lora,
                    num_tokens_post_padded_lora,
                    max_lora_rank,
                    top_k,
                    shrink_config,  ## pass the shrink config
                    expand_config,  ## pass the expand config
                    self.adapter_enabled,
                    # NOTE(review): positional flag, presumably "multiply by
                    # routed weight" -- confirm against add_lora_fused_moe.
                    True,
                    fully_sharded=self.fully_sharded,
                    # With fully sharded LoRA each rank owns one hidden_size
                    # shard of w2's LoRA B; offset selects this rank's shard.
                    offset=shard_size_w2 * self.tp_rank if self.fully_sharded else 0,
                )

                return func(*args, **kwargs)

            return wrapper

        fused_experts = m_fused_moe_fn.fused_experts

        m_fused_moe_fn.forward = fwd_decorator(self.base_layer, m_fused_moe_fn.forward)
        fused_experts.activation = act_decorator(
            self.base_layer, fused_experts.activation
        )
        fused_experts.moe_sum = moe_sum_decorator(
            self.base_layer, fused_experts.moe_sum
        )
        self.base_layer.quant_method = FusedMoEModularMethod(
            self.base_layer.quant_method, m_fused_moe_fn
        )

    def _create_lora_a_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
    ) -> None:
        """Allocate zero-filled LoRA A buffers for the w13 and w2 projections.

        Sets ``self.w13_lora_a_stacked`` (``self._w13_slices`` tensors) and
        ``self.w2_lora_a_stacked`` (a 1-tuple). Each tensor has shape
        ``(max_loras, local_num_experts, rank, input_size)``, where the input
        size is ``hidden_size`` for w13 and
        ``intermediate_size_per_partition`` for w2.

        With fully sharded LoRA, the w13 rank dimension is divided by the
        tensor-parallel size; the w2 rank dimension is never divided.
        """
        self.w13_lora_a_stacked: tuple[torch.Tensor, ...] = tuple(
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    lora_config.max_lora_rank
                    if not self.fully_sharded
                    else divide(lora_config.max_lora_rank, self.tp_size),
                    self.base_layer.hidden_size,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self._w13_slices)
        )
        self.w2_lora_a_stacked: tuple[torch.Tensor, ...] = (
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    lora_config.max_lora_rank,
                    self.base_layer.intermediate_size_per_partition,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )
    def _create_lora_b_weights(self, max_loras: int, lora_config: LoRAConfig) -> None:
        """Allocate zero-filled LoRA B buffers for the w13 and w2 projections.

        Sets ``self.w13_lora_b_stacked`` (``self._w13_slices`` tensors) and
        ``self.w2_lora_b_stacked`` (a 1-tuple). Each tensor has shape
        ``(max_loras, local_num_experts, output_size, max_lora_rank)``, where
        the output size is ``intermediate_size_per_partition`` for w13 and
        ``hidden_size`` for w2.

        With fully sharded LoRA, the w2 output dimension is divided by the
        tensor-parallel size.
        """
        self.w13_lora_b_stacked: tuple[torch.Tensor, ...] = tuple(
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    self.base_layer.intermediate_size_per_partition,
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self._w13_slices)
        )
        self.w2_lora_b_stacked: tuple[torch.Tensor, ...] = (
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    self.base_layer.hidden_size
                    if not self.fully_sharded
                    else divide(self.base_layer.hidden_size, self.tp_size),
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Initializes lora matrices.

        Allocates the stacked w13/w2 LoRA A and B buffers and the
        ``adapter_enabled`` flags (``max_loras + 1`` int entries, all zero),
        then builds the flat ``lora_a_stacked`` / ``lora_b_stacked`` lists.

        Args:
            max_loras: Number of LoRA slots to allocate buffers for.
            lora_config: Supplies ``max_loras``, ``fully_sharded_loras``,
                ``max_lora_rank`` and ``lora_dtype``.
            model_config: Not used by this method.
        """
        self.max_loras = lora_config.max_loras
        self.fully_sharded = lora_config.fully_sharded_loras

        self.adapter_enabled = torch.zeros(
            max_loras + 1, dtype=torch.int, device=self.device
        )

        self._create_lora_a_weights(max_loras, lora_config)
        self._create_lora_b_weights(max_loras, lora_config)
        # They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
        # to create a dummy LoRA weights.
        # TODO Optimize this section
        self.lora_a_stacked = []
        self.lora_b_stacked = []
        for lora_id in range(max_loras):
            for experts_id in range(self.base_layer.local_num_experts):
                # Per (lora, expert) the entries are ordered w13 slice 0, w2,
                # w13 slice 1 (gate_proj, down_proj, up_proj) -- the same
                # 3-per-expert layout that set_lora reads back.
                self.lora_a_stacked.extend(
                    (
                        self.w13_lora_a_stacked[0][lora_id][experts_id],
                        self.w2_lora_a_stacked[0][lora_id][experts_id],
                        self.w13_lora_a_stacked[1][lora_id][experts_id],
                    )
                )
                self.lora_b_stacked.extend(
                    (
                        self.w13_lora_b_stacked[0][lora_id][experts_id],
                        self.w2_lora_b_stacked[0][lora_id][experts_id],
                        self.w13_lora_b_stacked[1][lora_id][experts_id],
                    )
                )

    def reset_lora(self, index: int) -> None:
        """Resets the lora weights at index back to 0.

        Zeroes slot ``index`` in every stacked w13/w2 LoRA A and B buffer and
        clears the slot's ``adapter_enabled`` flag.
        """
        for pos in range(self._w13_slices):
            self.w13_lora_a_stacked[pos][index] = 0
            self.w13_lora_b_stacked[pos][index] = 0

        self.w2_lora_a_stacked[0][index] = 0
        self.w2_lora_b_stacked[0][index] = 0
        self.adapter_enabled[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Overwrites lora tensors at index.

        ``lora_a`` and ``lora_b`` must be flat lists holding three entries per
        expert, ordered ``(w1, w2, w3)``. An expert is skipped entirely when
        any of its three ``lora_a`` entries is ``None``.

        Under tensor parallelism, w1/w3 LoRA B are sliced along their first
        dimension and w2 LoRA A along its second dimension to this rank's
        ``intermediate_size_per_partition`` shard. With fully sharded LoRA,
        w1/w3 LoRA A and w2 LoRA B are additionally sliced along their first
        dimension to the size of the pre-allocated buffers.

        Args:
            index: LoRA slot to overwrite; it is reset first, then enabled.
            lora_a: LoRA A matrices (must be a list).
            lora_b: LoRA B matrices (must be a list).
        """
        assert isinstance(lora_a, list)
        assert isinstance(lora_b, list)
        self.reset_lora(index)
        self.adapter_enabled[index] = 1
        for eid in range(len(lora_a) // 3):
            w1_lora_a = lora_a[eid * 3]
            w2_lora_a = lora_a[eid * 3 + 1]
            w3_lora_a = lora_a[eid * 3 + 2]
            w1_lora_b = lora_b[eid * 3]
            w2_lora_b = lora_b[eid * 3 + 1]
            w3_lora_b = lora_b[eid * 3 + 2]

            # Handle the case of adding LoRA to only a subset of experts
            if w1_lora_a is None or w2_lora_a is None or w3_lora_a is None:
                continue

            if self.tp_size > 1:
                shard_size = self.base_layer.intermediate_size_per_partition
                start_idx = self.tp_rank * shard_size
                end_idx = (self.tp_rank + 1) * shard_size

                w1_lora_b = w1_lora_b[start_idx:end_idx, :]
                w3_lora_b = w3_lora_b[start_idx:end_idx, :]
                w2_lora_a = w2_lora_a[:, start_idx:end_idx]

                if self.fully_sharded:
                    # Based on S-LoRA, we slice W1 and W3 A along the rank dim,
                    # and W2 B along the hidden_size dim.
                    # The shard sizes are read from the allocated buffers.
                    w13_shard_size = self.w13_lora_a_stacked[0][index, eid].shape[0]
                    w13_start_idx = self.tp_rank * w13_shard_size
                    w13_end_idx = (self.tp_rank + 1) * w13_shard_size
                    w1_lora_a = w1_lora_a[w13_start_idx:w13_end_idx, :]
                    w3_lora_a = w3_lora_a[w13_start_idx:w13_end_idx, :]

                    w2_shard_size = self.w2_lora_b_stacked[0][index, eid].shape[0]
                    w2_start_idx = self.tp_rank * w2_shard_size
                    w2_end_idx = (self.tp_rank + 1) * w2_shard_size
                    w2_lora_b = w2_lora_b[w2_start_idx:w2_end_idx, :]

            # Each source is copied into the top-left corner of its slot; the
            # remainder of the slot keeps the zeros written by reset_lora.
            # w1 lora_a
            self.w13_lora_a_stacked[0][
                index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1]
            ].copy_(w1_lora_a, non_blocking=True)
            # w3 lora_a
            self.w13_lora_a_stacked[1][
                index, eid, : w3_lora_a.shape[0], : w3_lora_a.shape[1]
            ].copy_(w3_lora_a, non_blocking=True)

            # w1 lora_b
            self.w13_lora_b_stacked[0][
                index, eid, : w1_lora_b.shape[0], : w1_lora_b.shape[1]
            ].copy_(w1_lora_b, non_blocking=True)
            # w3 lora_b
            self.w13_lora_b_stacked[1][
                index, eid, : w3_lora_b.shape[0], : w3_lora_b.shape[1]
            ].copy_(w3_lora_b, non_blocking=True)

            # w2 lora_a
            self.w2_lora_a_stacked[0][
                index, eid, : w2_lora_a.shape[0], : w2_lora_a.shape[1]
            ].copy_(w2_lora_a, non_blocking=True)

            # w2 lora_b
            self.w2_lora_b_stacked[0][
                index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1]
            ].copy_(w2_lora_b, non_blocking=True)

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped layer's ``forward``, arguments unchanged."""
        return self.base_layer.forward(*args, **kwargs)

    def maybe_all_reduce_tensor_model_parallel(self, *args, **kwargs):
        """Delegate to the wrapped layer's method of the same name."""
        return self.base_layer.maybe_all_reduce_tensor_model_parallel(*args, **kwargs)

    @property
    def _shared_experts(self):
        """Expose the wrapped layer's ``_shared_experts`` attribute."""
        return self.base_layer._shared_experts

    @property
    def quant_method(self):
        """Expose the wrapped layer's current ``quant_method``."""
        return self.base_layer.quant_method

    @property
    def is_internal_router(self) -> bool:
        """Expose the wrapped layer's ``is_internal_router`` flag."""
        return self.base_layer.is_internal_router

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None,
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer.

        That is the case when ``source_layer`` is exactly a ``FusedMoE``
        (subclasses do not match) and ``packed_modules_list`` has two entries.
        ``lora_config`` and ``model_config`` are not consulted.
        """
        return type(source_layer) is FusedMoE and len(packed_modules_list) == 2

class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
    def __init__(self, base_layer):
        """Wrap ``base_layer`` using a single stacked w13 LoRA slice."""
        super().__init__(base_layer)
        # The parent constructor sets this to 2; it is overridden afterwards.
        # The hooks installed by the parent read ``self._w13_slices`` when they
        # run, not when they are created, so they see the value set here.
        self._w13_slices = 1

    def _create_lora_b_weights(self, max_loras, lora_config):
        self.w13_lora_b_stacked: tuple[torch.Tensor] = tuple(
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    self.base_layer.intermediate_size_per_partition * 2,
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self._w13_slices)
        )
        self.w2_lora_b_stacked: tuple[torch.Tensor] = (
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    self.base_layer.hidden_size
                    if not self.fully_sharded
                    else divide(self.base_layer.hidden_size, self.tp_size),
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Initializes lora matrices.

        Records ``max_loras`` and the fully-sharded flag from ``lora_config``,
        creates the ``adapter_enabled`` flags (``max_loras + 1`` zeros) and
        allocates the stacked LoRA A and B buffers. ``model_config`` is unused.
        """
        self.max_loras = lora_config.max_loras
        self.fully_sharded = lora_config.fully_sharded_loras

        num_flags = max_loras + 1
        self.adapter_enabled = torch.zeros(
            num_flags, dtype=torch.int, device=self.device
        )

        for allocate in (self._create_lora_a_weights, self._create_lora_b_weights):
            allocate(max_loras, lora_config)

    def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
        if self.tp_size == 1 or not self.fully_sharded:
            return w13_lora_a

        # w13_lora_a shape (num_experts,rank,input_size)
        current_lora_rank = w13_lora_a.shape[1]
        assert current_lora_rank % self.tp_size == 0

        sliced_rank = current_lora_rank // self.tp_size
        start_idx = self.tp_rank * sliced_rank
        end_idx = (self.tp_rank + 1) * sliced_rank
        return w13_lora_a[:, start_idx:end_idx, :]

    def _slice_w13_b(self, w13_lora_b: torch.Tensor, is_interleave: bool = True):
        if self.tp_size == 1:
            return w13_lora_b

        # w13_lora_b shape (num_experts,output_size,rank)
        shard_size = self.base_layer.intermediate_size_per_partition
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        if is_interleave:
            # For models like GPT-OSS, the weights of w1 (gate_proj) and w3 (up_proj)
            # in the interleaved order, and corresponding LoRA need to be processed.
            w1_lora_b = w13_lora_b[:, ::2, :]
            w3_lora_b = w13_lora_b[:, 1::2, :]
            sliced_w1_lora_b = w1_lora_b[:, start_idx:end_idx, :]
            sliced_w3_lora_b = w3_lora_b[:, start_idx:end_idx, :]

            return torch.stack([sliced_w1_lora_b, sliced_w3_lora_b], dim=2).flatten(
                1, 2
            )
        else:
            slice_size = w13_lora_b.shape[1] // 2
            w1_lora_b = w13_lora_b[:, :slice_size, :]
            w3_lora_b = w13_lora_b[:, slice_size:, :]
            sliced_w1_lora_b = w1_lora_b[:, start_idx:end_idx, :]
            sliced_w3_lora_b = w3_lora_b[:, start_idx:end_idx, :]

            return torch.cat([sliced_w1_lora_b, sliced_w3_lora_b], dim=1)

    def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
        if self.tp_size == 1:
            return w2_lora_a
        # w2_lora_a shape (num_experts,rank,input_size)
        shard_size = self.base_layer.intermediate_size_per_partition
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size

        return w2_lora_a[:, :, start_idx:end_idx]

    def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
        if self.tp_size == 1 or not self.fully_sharded:
            return w2_lora_b
        # Based on S-LoRA, we slice W2 B along the hidden_size dim.
        # w2_lora_b shape (num_experts,output_size,rank)
        current_lora_size = w2_lora_b.shape[1]

        sliced_size = current_lora_size // self.tp_size
        start_idx = self.tp_rank * sliced_size
        end_idx = (self.tp_rank + 1) * sliced_size
        return w2_lora_b[:, start_idx:end_idx, :]

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Overwrites lora tensors at index.

        ``lora_a`` and ``lora_b`` must each be a two-element list ordered
        ``(w13, w2)``, holding the weights of all experts at once. The tensors
        are regrouped per expert, sliced for this tensor-parallel rank and
        copied into slot ``index`` of the stacked buffers.
        """
        # Make mypy happy
        assert isinstance(lora_a, list)
        assert isinstance(lora_b, list)
        assert len(lora_a) == len(lora_b) == 2

        self.reset_lora(index)
        self.adapter_enabled[index] = 1

        num_experts = self.w13_lora_a_stacked[0].shape[1]

        def _per_expert_a(flat: torch.Tensor) -> torch.Tensor:
            # -> (num_experts, rank, input_size)
            return flat.reshape(num_experts, -1, flat.shape[-1])

        def _per_expert_b(flat: torch.Tensor) -> torch.Tensor:
            # (output_size, num_experts, rank) -> (num_experts, output_size, rank)
            return flat.reshape(flat.shape[0], num_experts, -1).permute(1, 0, 2)

        raw_w13_a, raw_w2_a = lora_a
        raw_w13_b, raw_w2_b = lora_b

        w13_a = self._slice_w13_a(_per_expert_a(raw_w13_a))
        w13_b = self._slice_w13_b(_per_expert_b(raw_w13_b), is_interleave=True)
        w2_a = self._slice_w2_a(_per_expert_a(raw_w2_a))
        w2_b = self._slice_w2_b(_per_expert_b(raw_w2_b))

        # Each source fills the leading corner of its slot for every expert.
        for stacked, source in (
            (self.w13_lora_a_stacked, w13_a),
            (self.w2_lora_a_stacked, w2_a),
            (self.w13_lora_b_stacked, w13_b),
            (self.w2_lora_b_stacked, w2_b),
        ):
            slot = stacked[0][index, :, : source.shape[1], : source.shape[2]]
            slot.copy_(source, non_blocking=True)

    @property
679
680
681
682
683
    def w13_input_size(self):
        """
        Full size
        """
        return self.w13_lora_a_stacked[0].shape[-1]

    @property
686
687
688
689
690
    def w13_output_size(self):
        """
        Full size
        """
        return self.w13_lora_b_stacked[0].shape[-2] * self.tp_size

    @property
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
    def w2_input_size(self):
        """
        Full size
        """
        return self.w2_lora_a_stacked[0].shape[-1] * self.tp_size

    @property
    def w2_output_size(self):
        """
        Full size
        """
        return self.w2_lora_a_stacked[0].shape[-2]

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None,
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer.

        Requires ``source_layer`` to be exactly a ``FusedMoE`` (subclasses do
        not match) and ``packed_modules_list`` to contain a single entry.
        """
        if type(source_layer) is not FusedMoE:
            return False
        return len(packed_modules_list) == 1