fused_moe.py 29 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm import envs
from vllm.config.lora import LoRAConfig
from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
15
from vllm.distributed.utils import divide
16
from vllm.lora.layers.base import BaseLayerWithLoRA
17
from vllm.lora.ops.triton_ops.utils import get_lora_op_configs
18
19
20
21
22
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import (
    _get_config_dtype_str,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
23
    MarlinExperts,
24
25
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
26
    TritonExperts,
27
)
28
29
30
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
    FusedMoEModularMethod,
)
31
32
33
34
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
    UnfusedOAITritonExperts,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
35
    FusedMoEKernel,
36
37
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
38
    MoEPrepareAndFinalizeNoDPEPModular,
39
)
40

41
from .utils import _get_lora_device, try_get_optimal_moe_lora_config
42

43
44
45
46
47

class FusedMoEWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper around a ``FusedMoE`` layer with separate w1/w2/w3 adapters.

    LoRA is applied by patching the layer's modular MoE kernel in place
    (see ``_inject_lora_into_fused_moe``):

    * the kernel's ``apply`` is wrapped to record the routing inputs of the
      current call;
    * the experts' ``activation`` is wrapped so the activation's input buffer
      is passed to ``punica_wrapper.add_lora_fused_moe`` together with the
      w13 (gate/up) LoRA weights before the original activation runs;
    * the experts' ``moe_sum`` is wrapped so its first positional argument is
      passed to ``punica_wrapper.add_lora_fused_moe`` together with the w2
      (down) LoRA weights before the original ``moe_sum`` runs.
    """

    def __init__(self, base_layer: FusedMoE) -> None:
        super().__init__()
        self.base_layer = base_layer

        assert not self.base_layer.use_ep, (
            "EP support for Fused MoE LoRA is not implemented yet."
        )

        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.device = _get_lora_device(base_layer)

        # For non-gated MoE (is_act_and_mul=False), only 1 slice is needed
        # since there's only up_proj (w1), not gate_proj + up_proj (w1 + w3)
        self._w13_slices = 2 if base_layer.moe_config.is_act_and_mul else 1

        self._inject_lora_into_fused_moe()

    def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]:
        """Convert lower-case kernel-config keys to the upper-case Triton names.

        ``block_<x>`` becomes ``BLOCK_SIZE_<X>`` (using the last ``_`` segment),
        any other all-lower-case key is upper-cased, and keys that are not
        all lower case are kept unchanged.
        """
        normalized_config = {}
        for key, value in config.items():
            if key.islower():
                if key.startswith("block_"):
                    normalized_key = "BLOCK_SIZE_" + key.split("_")[-1].upper()
                else:
                    normalized_key = key.upper()
            else:
                normalized_key = key
            normalized_config[normalized_key] = value
        return normalized_config

    def _get_lora_moe_configs(
        self,
        op_prefix: str,
        num_loras: int,
        rank: int,
        num_slices: int,
        M: int,
        layer: FusedMoE,
        top_k: int,
        config_dtype: str,
    ) -> tuple[dict[str, int | None], dict[str, int | None]]:
        """Return ``(shrink_config, expand_config)`` for the MoE LoRA kernels.

        Args:
            op_prefix: ``"w13"`` or ``"w2"``; selects the op type names
                ``fused_moe_lora_{op_prefix}_shrink`` / ``..._expand``.
            num_loras: Number of LoRA slots.
            rank: LoRA rank used for the lookup.
            num_slices: Number of w13 slices (ignored by the default path).
            M: Token count used as the batch dimension of the lookup.
            layer: The base ``FusedMoE`` layer (for weight shapes).
            top_k: Experts per token.
            config_dtype: Dtype string from ``_get_config_dtype_str``.

        When ``VLLM_TUNED_CONFIG_FOLDER`` is set, tuned configs are read via
        ``get_lora_op_configs``; otherwise ``try_get_optimal_moe_lora_config``
        provides defaults. Keys of both configs are normalized with
        ``_normalize_keys``.
        """
        if envs.VLLM_TUNED_CONFIG_FOLDER:
            hidden_size = layer.hidden_size
            intermediate_size = (
                self.w2_lora_a_stacked[0].shape[-1]
                if op_prefix == "w2"
                else self.w13_lora_b_stacked[0].shape[-2]
            )
            shrink_config = get_lora_op_configs(
                op_type=f"fused_moe_lora_{op_prefix}_shrink",
                max_loras=num_loras,
                batch=M,
                hidden_size=hidden_size,
                rank=rank,
                num_slices=num_slices,
                moe_intermediate_size=intermediate_size,
            )
            expand_config = get_lora_op_configs(
                op_type=f"fused_moe_lora_{op_prefix}_expand",
                max_loras=num_loras,
                batch=M,
                hidden_size=hidden_size,  # lora_a_stacked.shape[-1],
                rank=rank,
                num_slices=num_slices,
                moe_intermediate_size=intermediate_size,  # lora_b_stacked.shape[-2],
            )
        else:  # fall back to the default config
            get_config_func = functools.partial(
                try_get_optimal_moe_lora_config,
                w1_shape=layer.w13_weight.size(),
                w2_shape=layer.w2_weight.size(),
                rank=rank,
                top_k=top_k,
                dtype=config_dtype,
                M=M,
                block_shape=layer.quant_method.moe_quant_config.block_shape,
            )
            shrink_config = get_config_func(
                op_type=f"fused_moe_lora_{op_prefix}_shrink"
            )
            expand_config = get_config_func(
                op_type=f"fused_moe_lora_{op_prefix}_expand"
            )
        shrink_config = self._normalize_keys(shrink_config)
        expand_config = self._normalize_keys(expand_config)
        return shrink_config, expand_config

    def _inject_lora_into_fused_moe(self):
        """Patch the base layer's modular MoE kernel with LoRA hooks.

        Obtains (or builds) the layer's ``FusedMoEKernel``. It then wraps
        ``apply`` plus the experts' ``activation``/``moe_sum`` so the LoRA
        contributions are computed during the base MoE computation. Finally it
        installs the kernel on the base layer through ``FusedMoEModularMethod``.
        """
        # Per-call state shared by the apply / activation / moe_sum hooks:
        # `apply` records the routing inputs, `activation` records the token
        # alignment and its output, and `moe_sum` reads all of it back.
        moe_state_dict = {}
        top_k = self.base_layer.top_k

        self.base_layer.ensure_moe_quant_config_init()
        quant_config = self.base_layer.quant_method.moe_quant_config

        if getattr(self.base_layer.quant_method, "supports_internal_mk", False):
            # Use the existing modular kernel from the quant method
            m_fused_moe_fn = self.base_layer.quant_method.moe_kernel
            # Don't let the kernel own shared experts so the runner can
            # overlap them with routed experts via a separate CUDA stream.
            m_fused_moe_fn.shared_experts = None
        else:
            # Create a new modular kernel via select_gemm_impl.
            # Don't pass shared_experts to the kernel so the runner can
            # overlap them with routed experts via a separate CUDA stream.
            prepare_finalize = MoEPrepareAndFinalizeNoDPEPModular()
            m_fused_moe_fn = FusedMoEKernel(
                prepare_finalize,
                self.base_layer.quant_method.select_gemm_impl(
                    prepare_finalize, self.base_layer
                ),
            )

        # Only these experts implementations are supported by the hooks below.
        if quant_config.use_mxfp4_w4a16:
            assert isinstance(
                m_fused_moe_fn.impl.fused_experts,
                (MarlinExperts, UnfusedOAITritonExperts),
            )
        else:
            assert isinstance(m_fused_moe_fn.impl.fused_experts, TritonExperts)

        def fwd_decorator(layer, func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                # Record the routing inputs of this call for the later hooks.
                # They must be passed as keyword arguments (KeyError otherwise).
                for key in (
                    "hidden_states",
                    "topk_ids",
                    "topk_weights",
                    "expert_map",
                    "apply_router_weight_on_input",
                ):
                    moe_state_dict[key] = kwargs[key]
                return func(*args, **kwargs)

            return wrapper

        def act_decorator(layer, func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                # Positional call shape is (<unused>, output, input).
                _, output, act_input = args

                hidden_states = moe_state_dict["hidden_states"]
                topk_weights = moe_state_dict["topk_weights"]
                curr_topk_ids = moe_state_dict["topk_ids"]
                expert_map = moe_state_dict["expert_map"]

                config_dtype = _get_config_dtype_str(
                    dtype=hidden_states.dtype,
                    use_fp8_w8a8=False,
                    use_int8_w8a16=False,
                    use_int4_w4a16=False,
                )
                CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
                num_tokens = hidden_states.size(0)
                M = min(num_tokens, CHUNK_SIZE)
                max_lora_rank = self.w13_lora_a_stacked[0].shape[-2]
                shrink_config, expand_config = self._get_lora_moe_configs(
                    op_prefix="w13",
                    num_loras=self.max_loras,
                    rank=max_lora_rank,
                    num_slices=self._w13_slices,
                    M=M,
                    layer=layer,
                    top_k=top_k,
                    config_dtype=config_dtype,
                )

                # SPARSITY_FACTOR is a heuristic margin ensuring tokens * top_k
                # activates only a small fraction of total experts * loras.
                SPARSITY_FACTOR = 8
                naive_block_assignment = (
                    expert_map is None
                    and num_tokens * top_k * SPARSITY_FACTOR
                    <= self.base_layer.local_num_experts * self.max_loras
                )

                # Align the routing to BLOCK_SIZE_M taken from the shrink config
                # (customized or default).
                (
                    token_lora_mapping,
                    sorted_token_ids_lora,
                    expert_ids_lora,
                    num_tokens_post_padded_lora,
                ) = self.punica_wrapper.moe_lora_align_block_size(
                    curr_topk_ids,
                    num_tokens,
                    shrink_config["BLOCK_SIZE_M"],
                    self.base_layer.local_num_experts,
                    self.max_loras,
                    self.adapter_enabled,
                    expert_map,
                    naive_block_assignment=naive_block_assignment,
                )

                # Cache the alignment (before any reshaping) for the moe_sum hook.
                moe_state_dict["sorted_token_ids_lora"] = sorted_token_ids_lora
                moe_state_dict["expert_ids_lora"] = expert_ids_lora
                moe_state_dict["num_tokens_post_padded_lora"] = (
                    num_tokens_post_padded_lora
                )
                moe_state_dict["token_lora_mapping"] = token_lora_mapping

                if sorted_token_ids_lora is not None:
                    expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
                    sorted_token_ids_lora = sorted_token_ids_lora.view(
                        self.max_loras, -1
                    )

                self.punica_wrapper.add_lora_fused_moe(
                    act_input.view(-1, top_k, act_input.shape[-1]),
                    hidden_states,
                    self.w13_lora_a_stacked,
                    self.w13_lora_b_stacked,
                    topk_weights,
                    sorted_token_ids_lora,
                    expert_ids_lora,
                    num_tokens_post_padded_lora,
                    max_lora_rank,
                    top_k,
                    shrink_config,  ## pass the shrink config
                    expand_config,  ## pass the expand config
                    self.adapter_enabled,
                    fully_sharded=self.fully_sharded,
                    token_lora_mapping=token_lora_mapping,
                )

                result = func(*args, **kwargs)

                # The activation output is the LoRA input of the w2 hook.
                moe_state_dict["intermediate_cache2"] = output
                return result

            return wrapper

        def moe_sum_decorator(layer, func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                hidden_states = moe_state_dict["hidden_states"]
                topk_weights = moe_state_dict["topk_weights"]

                config_dtype = _get_config_dtype_str(
                    dtype=hidden_states.dtype,
                    use_fp8_w8a8=False,
                    use_int8_w8a16=False,
                    use_int4_w4a16=False,
                )
                CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
                num_tokens = hidden_states.size(0)
                M = min(num_tokens, CHUNK_SIZE)
                max_lora_rank = self.w2_lora_a_stacked[0].shape[-2]
                shrink_config, expand_config = self._get_lora_moe_configs(
                    op_prefix="w2",
                    num_loras=self.max_loras,
                    rank=max_lora_rank,
                    num_slices=1,
                    M=M,
                    layer=layer,
                    top_k=top_k,
                    config_dtype=config_dtype,
                )

                # Reuse the alignment computed by the activation hook.
                sorted_token_ids_lora = moe_state_dict["sorted_token_ids_lora"]
                expert_ids_lora = moe_state_dict["expert_ids_lora"]
                num_tokens_post_padded_lora = moe_state_dict[
                    "num_tokens_post_padded_lora"
                ]
                token_lora_mapping = moe_state_dict.get("token_lora_mapping")

                if sorted_token_ids_lora is not None:
                    expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
                    sorted_token_ids_lora = sorted_token_ids_lora.view(
                        self.max_loras, -1
                    )

                intermediate_cache2 = moe_state_dict["intermediate_cache2"]
                intermediate_cache3 = args[0]

                shard_size_w2 = divide(self.base_layer.hidden_size, self.tp_size)

                self.punica_wrapper.add_lora_fused_moe(
                    intermediate_cache3,
                    intermediate_cache2,
                    self.w2_lora_a_stacked,
                    self.w2_lora_b_stacked,
                    topk_weights,
                    sorted_token_ids_lora,
                    expert_ids_lora,
                    num_tokens_post_padded_lora,
                    max_lora_rank,
                    top_k,
                    shrink_config,  ## pass the shrink config
                    expand_config,  ## pass the expand config
                    self.adapter_enabled,
                    # NOTE(review): positional flag that the w13 call omits;
                    # confirm its meaning against add_lora_fused_moe.
                    True,
                    fully_sharded=self.fully_sharded,
                    offset=shard_size_w2 * self.tp_rank if self.fully_sharded else 0,
                    token_lora_mapping=token_lora_mapping,
                )

                return func(*args, **kwargs)

            return wrapper

        fused_experts = m_fused_moe_fn.impl.fused_experts

        m_fused_moe_fn.apply = fwd_decorator(self.base_layer, m_fused_moe_fn.apply)
        fused_experts.activation = act_decorator(
            self.base_layer, fused_experts.activation
        )
        fused_experts.moe_sum = moe_sum_decorator(
            self.base_layer, fused_experts.moe_sum
        )
        # TODO(bnell): find a less intrusive way to handle this.
        self.base_layer._replace_quant_method(
            FusedMoEModularMethod(self.base_layer.quant_method, m_fused_moe_fn)
        )

    def _create_lora_a_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
    ):
        """Allocate zero-initialized LoRA-A buffers.

        w13: ``_w13_slices`` tensors of shape
        (max_loras, local_num_experts, rank, hidden_size); the rank dim is
        divided by ``tp_size`` when LoRAs are fully sharded.
        w2: one tensor of shape
        (max_loras, local_num_experts, rank, intermediate_size_per_partition).
        """
        self.w13_lora_a_stacked: tuple[torch.Tensor, ...] = tuple(
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    lora_config.max_lora_rank
                    if not self.fully_sharded
                    else divide(lora_config.max_lora_rank, self.tp_size),
                    self.base_layer.hidden_size,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self._w13_slices)
        )
        self.w2_lora_a_stacked: tuple[torch.Tensor, ...] = (
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    lora_config.max_lora_rank,
                    self.base_layer.intermediate_size_per_partition,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )

    def _create_lora_b_weights(self, max_loras: int, lora_config: LoRAConfig):
        """Allocate zero-initialized LoRA-B buffers.

        w13: ``_w13_slices`` tensors of shape
        (max_loras, local_num_experts, intermediate_size_per_partition, rank).
        w2: one tensor of shape (max_loras, local_num_experts, hidden_size, rank);
        hidden_size is divided by ``tp_size`` when LoRAs are fully sharded.
        """
        self.w13_lora_b_stacked: tuple[torch.Tensor, ...] = tuple(
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    self.base_layer.intermediate_size_per_partition,
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self._w13_slices)
        )
        self.w2_lora_b_stacked: tuple[torch.Tensor, ...] = (
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    self.base_layer.hidden_size
                    if not self.fully_sharded
                    else divide(self.base_layer.hidden_size, self.tp_size),
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Initializes lora matrices."""
        self.max_loras = lora_config.max_loras
        self.fully_sharded = lora_config.fully_sharded_loras

        # One flag per LoRA slot plus one extra entry.
        self.adapter_enabled = torch.tensor(
            [0] * (max_loras + 1), dtype=torch.int, device=self.device
        )

        self._create_lora_a_weights(max_loras, lora_config)
        self._create_lora_b_weights(max_loras, lora_config)
        # They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
        # to create a dummy LoRA weights.
        # TODO Optimize this section
        self.lora_a_stacked = []
        self.lora_b_stacked = []
        for lora_id in range(max_loras):
            for experts_id in range(self.base_layer.local_num_experts):
                # For gated MoE: gate_proj (w1), down_proj (w2), up_proj (w3)
                # For non-gated MoE: up_proj (w1), down_proj (w2)
                self.lora_a_stacked.append(
                    self.w13_lora_a_stacked[0][lora_id][experts_id]
                )
                self.lora_a_stacked.append(
                    self.w2_lora_a_stacked[0][lora_id][experts_id]
                )

                self.lora_b_stacked.append(
                    self.w13_lora_b_stacked[0][lora_id][experts_id]
                )
                self.lora_b_stacked.append(
                    self.w2_lora_b_stacked[0][lora_id][experts_id]
                )

                # Only add w3 (up_proj) for gated MoE (_w13_slices == 2)
                if self._w13_slices == 2:
                    self.lora_a_stacked.append(
                        self.w13_lora_a_stacked[1][lora_id][experts_id]
                    )
                    self.lora_b_stacked.append(
                        self.w13_lora_b_stacked[1][lora_id][experts_id]
                    )

    def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
        """
        Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
        """
        if self.tp_size == 1 or not self.fully_sharded:
            return w13_lora_a

        # w13_lora_a shape (num_experts,rank,input_size)
        current_lora_rank = w13_lora_a.shape[1]
        assert current_lora_rank % self.tp_size == 0
        # Based on S-LoRA, we slice W13/W1/W3 A along the rank dim.
        shard_size = self.w13_lora_a_stacked[0].shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        return w13_lora_a[:, start_idx:end_idx, :]

    def _slice_w13_b(self, w13_lora_b: torch.Tensor):
        """Take this rank's slice of a w1/w3 LoRA-B along the output dim."""
        if self.tp_size == 1:
            return w13_lora_b

        # w13_lora_b shape (num_experts,output_size,rank)
        shard_size = self.base_layer.intermediate_size_per_partition
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size

        return w13_lora_b[:, start_idx:end_idx, :]

    def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
        """
        Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
        """
        if self.tp_size == 1:
            return w2_lora_a
        # w2_lora_a shape (num_experts,rank,input_size)
        shard_size = self.base_layer.intermediate_size_per_partition
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size

        return w2_lora_a[:, :, start_idx:end_idx]

    def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
        """
        Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
        """
        if self.tp_size == 1 or not self.fully_sharded:
            return w2_lora_b
        # Based on S-LoRA, we slice W2 B along the hidden_size dim.
        # w2_lora_b shape (num_experts,output_size,rank)
        shard_size = self.w2_lora_b_stacked[0].shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size

        return w2_lora_b[:, start_idx:end_idx, :]

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        for pos in range(self._w13_slices):
            self.w13_lora_a_stacked[pos][index] = 0
            self.w13_lora_b_stacked[pos][index] = 0

        self.w2_lora_a_stacked[0][index] = 0
        self.w2_lora_b_stacked[0][index] = 0
        self.adapter_enabled[index] = 0

    @staticmethod
    def _copy_lora_slice(dst: torch.Tensor, index: int, src: torch.Tensor) -> None:
        """Copy 3D ``src`` into the leading ``src.shape[1:]`` corner of ``dst[index]``.

        The destination buffers are sized for the maximum rank, so a smaller
        adapter only fills the corner; callers reset the slot first.
        """
        dst[index, :, : src.shape[1], : src.shape[2]].copy_(src, non_blocking=True)

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Overwrites lora tensors at index.

        ``lora_a`` and ``lora_b`` are lists ordered ``[w1, w2, w3]``
        (gate_proj, down_proj, up_proj). For non-gated MoE
        (``_w13_slices == 1``) the w3 entry is never used, so it may be
        ``None`` or omitted (a 2-element list).
        """
        # Make mypy happy
        assert isinstance(lora_a, list)
        assert isinstance(lora_b, list)

        is_gated = self._w13_slices == 2
        num_entries = len(lora_a)
        assert num_entries == len(lora_b), "lora_a and lora_b length mismatch"
        assert num_entries == 3 or (not is_gated and num_entries == 2), (
            f"Expected [w1, w2, w3] LoRA weights, got {num_entries} entries"
        )

        num_experts = self.w13_lora_a_stacked[0].shape[1]
        w1_lora_a, w2_lora_a = lora_a[0], lora_a[1]
        w1_lora_b, w2_lora_b = lora_b[0], lora_b[1]
        assert num_experts == w1_lora_a.shape[0] == w2_lora_a.shape[0]
        if is_gated:
            assert lora_a[2].shape[0] == num_experts

        # Validate before touching the slot so a malformed adapter does not
        # leave it enabled with zeroed weights.
        self.reset_lora(index)
        self.adapter_enabled[index] = 1

        self._copy_lora_slice(
            self.w13_lora_a_stacked[0], index, self._slice_w13_a(w1_lora_a)
        )
        self._copy_lora_slice(
            self.w13_lora_b_stacked[0], index, self._slice_w13_b(w1_lora_b)
        )

        # Only copy w3 (up_proj) for gated MoE (_w13_slices == 2)
        if is_gated:
            w3_lora_a, w3_lora_b = lora_a[2], lora_b[2]
            self._copy_lora_slice(
                self.w13_lora_a_stacked[1], index, self._slice_w13_a(w3_lora_a)
            )
            self._copy_lora_slice(
                self.w13_lora_b_stacked[1], index, self._slice_w13_b(w3_lora_b)
            )

        self._copy_lora_slice(
            self.w2_lora_a_stacked[0], index, self._slice_w2_a(w2_lora_a)
        )
        self._copy_lora_slice(
            self.w2_lora_b_stacked[0], index, self._slice_w2_b(w2_lora_b)
        )

    def forward(self, *args, **kwargs):
        return self.base_layer.forward(*args, **kwargs)

    def maybe_all_reduce_tensor_model_parallel(self, *args, **kwargs):
        return self.base_layer.maybe_all_reduce_tensor_model_parallel(*args, **kwargs)

    @property
    def _shared_experts(self):
        return self.base_layer._shared_experts

    @property
    def quant_method(self):
        return self.base_layer.quant_method

    @property
    def is_internal_router(self) -> bool:
        return self.base_layer.is_internal_router

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        # source_layer is FusedMoE or SharedFusedMoE
        return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 2
621

622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663

class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
    """LoRA wrapper for MoE layers whose gate/up LoRA arrives as one w13 tensor.

    Adapters are given as ``[w13, w2]`` pairs. The single w13 LoRA-B buffer
    holds both the w1 (gate) and w3 (up) projections, so its output dim is
    ``2 * intermediate_size_per_partition``.
    """

    def __init__(self, base_layer):
        super().__init__(base_layer)
        # w1 and w3 share one packed w13 slice.
        self._w13_slices = 1

    def _create_lora_b_weights(self, max_loras, lora_config):
        """Allocate zero-initialized LoRA-B buffers for packed w13 and for w2."""
        self.w13_lora_b_stacked: tuple[torch.Tensor, ...] = tuple(
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    self.base_layer.intermediate_size_per_partition * 2,
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self._w13_slices)
        )
        self.w2_lora_b_stacked: tuple[torch.Tensor, ...] = (
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    self.base_layer.hidden_size
                    if not self.fully_sharded
                    else divide(self.base_layer.hidden_size, self.tp_size),
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Initializes lora matrices.

        ``model_config`` is required. Its first architecture name selects the
        w13 LoRA-B layout used by ``_slice_w13_b``.
        """
        assert isinstance(model_config, PretrainedConfig)
        # `architectures` may be unset (None) or empty; fall back to the
        # default (non-interleaved) w13 layout instead of failing on indexing.
        architectures = model_config.architectures or []
        self._base_model: str | None = architectures[0] if architectures else None

        self.max_loras = lora_config.max_loras
        self.fully_sharded = lora_config.fully_sharded_loras

        self.adapter_enabled = torch.tensor(
            [0] * (max_loras + 1), dtype=torch.int, device=self.device
        )

        self._create_lora_a_weights(max_loras, lora_config)
        self._create_lora_b_weights(max_loras, lora_config)

    def _slice_w13_b(self, w13_lora_b: torch.Tensor):
        """Take this rank's w1 and w3 rows out of a packed w13 LoRA-B.

        The packed tensor is split into its w1 and w3 halves. Each half is
        sliced to this rank's ``intermediate_size_per_partition`` rows, and the
        two slices are repacked in the same layout as the input.
        """
        if self.tp_size == 1:
            return w13_lora_b

        # w13_lora_b shape (num_experts,output_size,rank)
        shard_size = self.base_layer.intermediate_size_per_partition
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size

        # HACK: Currently, only GPT-OSS is in interleaved order
        if self._base_model == "GptOssForCausalLM":
            # For models like GPT-OSS, the weights of w1 (gate_proj) and w3
            # (up_proj) are in interleaved order (even rows w1, odd rows w3),
            # and the corresponding LoRA needs the same treatment.
            w1_lora_b = w13_lora_b[:, ::2, :]
            w3_lora_b = w13_lora_b[:, 1::2, :]
            sliced_w1_lora_b = w1_lora_b[:, start_idx:end_idx, :]
            sliced_w3_lora_b = w3_lora_b[:, start_idx:end_idx, :]

            # Stack on a new dim then flatten to re-interleave the rows.
            return torch.stack([sliced_w1_lora_b, sliced_w3_lora_b], dim=2).flatten(
                1, 2
            )
        else:
            # Non-interleaved layout: first half w1, second half w3.
            slice_size = w13_lora_b.shape[1] // 2
            w1_lora_b = w13_lora_b[:, :slice_size, :]
            w3_lora_b = w13_lora_b[:, slice_size:, :]
            sliced_w1_lora_b = w1_lora_b[:, start_idx:end_idx, :]
            sliced_w3_lora_b = w3_lora_b[:, start_idx:end_idx, :]

            return torch.cat([sliced_w1_lora_b, sliced_w3_lora_b], dim=1)

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Overwrites lora tensors at index.

        ``lora_a`` and ``lora_b`` are ``[w13, w2]`` lists.
        """
        # Make mypy happy
        assert isinstance(lora_a, list)
        assert isinstance(lora_b, list)
        assert len(lora_a) == len(lora_b) == 2

        self.reset_lora(index)
        self.adapter_enabled[index] = 1

        w13_lora_a, w2_lora_a = lora_a
        w13_lora_b, w2_lora_b = lora_b

        sliced_w13_lora_a = self._slice_w13_a(w13_lora_a)
        sliced_w13_lora_b = self._slice_w13_b(w13_lora_b)

        sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
        sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)

        # Buffers are sized for the max rank; fill only the leading corner.
        self.w13_lora_a_stacked[0][
            index, :, : sliced_w13_lora_a.shape[1], : sliced_w13_lora_a.shape[2]
        ].copy_(sliced_w13_lora_a, non_blocking=True)
        self.w2_lora_a_stacked[0][
            index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
        ].copy_(sliced_w2_lora_a, non_blocking=True)

        self.w13_lora_b_stacked[0][
            index, :, : sliced_w13_lora_b.shape[1], : sliced_w13_lora_b.shape[2]
        ].copy_(sliced_w13_lora_b, non_blocking=True)
        self.w2_lora_b_stacked[0][
            index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
        ].copy_(sliced_w2_lora_b, non_blocking=True)

    @property
    def w13_input_size(self):
        """
        Full size
        """
        return self.w13_lora_a_stacked[0].shape[-1]

    @property
    def w13_output_size(self):
        """
        Full size
        """
        return self.w13_lora_b_stacked[0].shape[-2] * self.tp_size

    @property
    def w2_input_size(self):
        """
        Full size
        """
        return self.w2_lora_a_stacked[0].shape[-1] * self.tp_size

    @property
    def w2_output_size(self):
        """
        Full size
        """
        return self.base_layer.hidden_size

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        # source_layer is FusedMoE or SharedFusedMoE
        return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 1