# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm import envs
from vllm.config.lora import LoRAConfig
from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from vllm.lora.layers.base import BaseLayerWithLoRA
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import (
    FUSED_MOE_UNQUANTIZED_CONFIG,
    _get_config_dtype_str,
    mxfp4_w4a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
    modular_marlin_fused_moe,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
    modular_triton_fused_moe,
    try_get_optimal_moe_config,
)
from vllm.model_executor.layers.quantization.mxfp4 import Mxfp4Config


class FusedMoEWithLoRA(BaseLayerWithLoRA):
    """LoRA-enabled wrapper around a :class:`FusedMoE` layer.

    On construction the base layer's ``quant_method.fused_experts`` is replaced
    by a modular fused-MoE kernel (Marlin when the quant config reports
    ``use_mxfp4_w4a16``, Triton otherwise). That kernel's ``forward``,
    ``activation`` and ``moe_sum`` hooks are wrapped so per-expert LoRA deltas
    are applied for the w1/w3 (gate/up) and w2 (down) projections.
    """

    def __init__(self, base_layer: FusedMoE) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        # LoRA stacks are allocated on the same device as the expert weights.
        self.device = base_layer.w2_weight.device
        self._inject_lora_into_fused_moe()

    def _inject_lora_into_fused_moe(self):
        """Install a LoRA-aware modular fused-MoE kernel on the base layer.

        The wrapped hooks communicate through ``moe_state_dict``:

        * ``forward`` records the routing keyword arguments of each call.
        * ``activation`` aligns routed tokens into per-LoRA blocks, adds the
          w1/w3 LoRA contribution, then records the activation output.
        * ``moe_sum`` reuses that alignment to add the w2 LoRA contribution.

        ``moe_sum`` therefore relies on ``activation`` having run earlier in
        the same forward pass. The original kernel is kept on the quant method
        as ``old_fused_experts``.
        """
        moe_state_dict = {}
        base_layer = self.base_layer
        top_k = base_layer.top_k

        if base_layer.quant_config is None:
            quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
        elif not isinstance(base_layer.quant_config, Mxfp4Config):
            # NOTE(review): the layer's quant config is used as-is here; it
            # must expose ``use_mxfp4_w4a16`` (read below) -- confirm.
            quant_config = base_layer.quant_config
        else:
            quant_config = mxfp4_w4a16_moe_quant_config(
                w1_bias=base_layer.w13_bias,
                w2_bias=base_layer.w2_bias,
                w1_scale=base_layer.w13_weight_scale,
                w2_scale=base_layer.w2_weight_scale,
            )

        if quant_config.use_mxfp4_w4a16:
            m_fused_moe_fn = modular_marlin_fused_moe(
                quant_config, shared_experts=base_layer.shared_experts
            )
        else:
            m_fused_moe_fn = modular_triton_fused_moe(
                quant_config, shared_experts=base_layer.shared_experts
            )

        # Keyword arguments of the kernel's ``forward`` that the LoRA hooks
        # read later; a missing key raises KeyError as before.
        forward_kwargs_to_record = (
            "hidden_states",
            "topk_ids",
            "topk_weights",
            "global_num_experts",
            "expert_map",
            "apply_router_weight_on_input",
        )

        def get_lora_moe_config(hidden_states):
            """Return the fused-MoE tile config used by the LoRA kernels."""
            config_dtype = _get_config_dtype_str(
                dtype=hidden_states.dtype,
                use_fp8_w8a8=False,
                use_int8_w8a16=False,
                use_int4_w4a16=False,
            )
            M = min(hidden_states.size(0), envs.VLLM_FUSED_MOE_CHUNK_SIZE)
            return try_get_optimal_moe_config(
                base_layer.w13_weight.size(),
                base_layer.w2_weight.size(),
                top_k,
                config_dtype,
                M,
                block_shape=base_layer.quant_method.moe_quant_config.block_shape,
            )

        def wrap_forward(func):
            """Record the routing kwargs of every kernel ``forward`` call."""

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                for key in forward_kwargs_to_record:
                    moe_state_dict[key] = kwargs[key]
                return func(*args, **kwargs)

            return wrapper

        def wrap_activation(func):
            """Add the w1/w3 LoRA into the activation input, then activate."""

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                # Positional args: (activation name, output, input).
                _, output, act_input = args
                hidden_states = moe_state_dict["hidden_states"]

                config = get_lora_moe_config(hidden_states)
                max_loras = self.w1_lora_a_stacked.shape[0]
                (
                    sorted_token_ids_lora,
                    expert_ids_lora,
                    num_tokens_post_padded_lora,
                ) = self.punica_wrapper.moe_lora_align_block_size(
                    moe_state_dict["topk_ids"],
                    hidden_states.size(0),
                    config["BLOCK_SIZE_M"],
                    moe_state_dict["global_num_experts"],
                    max_loras,
                    moe_state_dict["expert_map"],
                )
                # One row per LoRA slot.
                sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
                expert_ids_lora = expert_ids_lora.view(max_loras, -1)

                # Cached so ``moe_sum`` can apply the w2 LoRA with the same
                # config and token alignment.
                moe_state_dict["lora_moe_config"] = config
                moe_state_dict["sorted_token_ids_lora"] = sorted_token_ids_lora
                moe_state_dict["expert_ids_lora"] = expert_ids_lora
                moe_state_dict["num_tokens_post_padded_lora"] = (
                    num_tokens_post_padded_lora
                )

                # NOTE(review): assumes add_lora_fused_moe accumulates into its
                # first argument in place -- confirm against punica_wrapper.
                self.punica_wrapper.add_lora_fused_moe(
                    act_input.view(-1, top_k, act_input.shape[-1]),
                    hidden_states,
                    [self.w1_lora_a_stacked, self.w3_lora_a_stacked],
                    [self.w1_lora_b_stacked, self.w3_lora_b_stacked],
                    moe_state_dict["topk_weights"],
                    sorted_token_ids_lora,
                    expert_ids_lora,
                    num_tokens_post_padded_lora,
                    self.w1_lora_a_stacked.shape[-2],
                    top_k,
                    config,
                )

                result = func(*args, **kwargs)
                # The activation output is the input of the w2 LoRA.
                moe_state_dict["intermediate_cache2"] = output
                return result

            return wrapper

        def wrap_moe_sum(func):
            """Add the w2 LoRA into the expert outputs, then sum them."""

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                intermediate_cache3 = args[0]
                self.punica_wrapper.add_lora_fused_moe(
                    intermediate_cache3,
                    moe_state_dict["intermediate_cache2"],
                    [self.w2_lora_a_stacked],
                    [self.w2_lora_b_stacked],
                    moe_state_dict["topk_weights"],
                    moe_state_dict["sorted_token_ids_lora"],
                    moe_state_dict["expert_ids_lora"],
                    moe_state_dict["num_tokens_post_padded_lora"],
                    self.w2_lora_a_stacked.shape[-2],
                    top_k,
                    moe_state_dict["lora_moe_config"],
                    # NOTE(review): trailing positional flag, presumably
                    # enabling routing-weight multiplication -- confirm.
                    True,
                )
                return func(*args, **kwargs)

            return wrapper

        fused_experts = m_fused_moe_fn.fused_experts
        m_fused_moe_fn.forward = wrap_forward(m_fused_moe_fn.forward)
        fused_experts.activation = wrap_activation(fused_experts.activation)
        fused_experts.moe_sum = wrap_moe_sum(fused_experts.moe_sum)

        quant_method = base_layer.quant_method
        quant_method.old_fused_experts = quant_method.fused_experts
        quant_method.fused_experts = m_fused_moe_fn

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Initializes lora matrices.

        Allocates zero-filled stacks shaped
        ``(max_loras, global_num_experts, rows, cols)`` for the A and B
        matrices of w1, w2 and w3.
        """

        assert not self.base_layer.use_ep, (
            "EP support for Fused MoE LoRA is not implemented yet."
        )

        num_experts = self.base_layer.global_num_experts
        rank = lora_config.max_lora_rank
        hidden_size = self.base_layer.hidden_size
        intermediate_size = self.base_layer.intermediate_size_per_partition

        def new_stack(rows: int, cols: int) -> torch.Tensor:
            return torch.zeros(
                (max_loras, num_experts, rows, cols),
                dtype=lora_config.lora_dtype,
                device=self.device,
            )

        self.w1_lora_a_stacked = new_stack(rank, hidden_size)
        self.w1_lora_b_stacked = new_stack(intermediate_size, rank)
        self.w2_lora_a_stacked = new_stack(rank, intermediate_size)
        self.w2_lora_b_stacked = new_stack(hidden_size, rank)
        self.w3_lora_a_stacked = new_stack(rank, hidden_size)
        self.w3_lora_b_stacked = new_stack(intermediate_size, rank)

        # They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
        # to create a dummy LoRA weights.
        self.lora_a_stacked = []
        self.lora_b_stacked = []
        for lora_id in range(max_loras):
            for expert_id in range(num_experts):
                # Per-expert order: gate_proj (w1), down_proj (w2), up_proj (w3).
                self.lora_a_stacked.extend(
                    (
                        self.w1_lora_a_stacked[lora_id][expert_id],
                        self.w2_lora_a_stacked[lora_id][expert_id],
                        self.w3_lora_a_stacked[lora_id][expert_id],
                    )
                )
                self.lora_b_stacked.extend(
                    (
                        self.w1_lora_b_stacked[lora_id][expert_id],
                        self.w2_lora_b_stacked[lora_id][expert_id],
                        self.w3_lora_b_stacked[lora_id][expert_id],
                    )
                )

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        for stacked in (
            self.w1_lora_a_stacked,
            self.w1_lora_b_stacked,
            self.w3_lora_a_stacked,
            self.w3_lora_b_stacked,
            self.w2_lora_a_stacked,
            self.w2_lora_b_stacked,
        ):
            stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: torch.Tensor | None,
        bias: torch.Tensor | None = None,
    ):
        """Overwrites lora tensors at index.

        ``lora_a`` / ``lora_b`` hold three entries per expert in the order
        (w1, w2, w3). The slot is zeroed first, so experts whose A matrices
        are ``None`` end up without a LoRA. With tensor parallelism, the
        intermediate dimension of w1/w3 B and w2 A is sliced to this rank's
        shard.
        """
        self.reset_lora(index)

        if self.tp_size > 1:
            shard_size = self.base_layer.intermediate_size_per_partition
            start_idx = self.tp_rank * shard_size
            end_idx = start_idx + shard_size

        def copy_into(stacked: torch.Tensor, eid: int, weight: torch.Tensor):
            # Weights may be smaller than the stack (e.g. rank < max rank);
            # copy into the leading corner.
            stacked[index, eid, : weight.shape[0], : weight.shape[1]].copy_(
                weight, non_blocking=True
            )

        for eid in range(len(lora_a) // 3):
            base = eid * 3
            w1_lora_a = lora_a[base]
            w2_lora_a = lora_a[base + 1]
            w3_lora_a = lora_a[base + 2]
            w1_lora_b = lora_b[base]
            w2_lora_b = lora_b[base + 1]
            w3_lora_b = lora_b[base + 2]

            # Handle the case of adding LoRA to only a subset of experts
            if w1_lora_a is None or w2_lora_a is None or w3_lora_a is None:
                continue

            if self.tp_size > 1:
                w1_lora_b = w1_lora_b[start_idx:end_idx, :]
                w3_lora_b = w3_lora_b[start_idx:end_idx, :]
                w2_lora_a = w2_lora_a[:, start_idx:end_idx]

            copy_into(self.w1_lora_a_stacked, eid, w1_lora_a)
            copy_into(self.w3_lora_a_stacked, eid, w3_lora_a)
            copy_into(self.w2_lora_b_stacked, eid, w2_lora_b)
            copy_into(self.w1_lora_b_stacked, eid, w1_lora_b)
            copy_into(self.w3_lora_b_stacked, eid, w3_lora_b)
            copy_into(self.w2_lora_a_stacked, eid, w2_lora_a)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None,
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        return isinstance(source_layer, FusedMoE)

    def forward(self, *args, **kwargs):
        """Delegate to the base layer (LoRA is applied inside its kernel)."""
        return self.base_layer.forward(*args, **kwargs)

    def maybe_all_reduce_tensor_model_parallel(self, *args, **kwargs):
        """Delegate to the base layer's tensor-parallel all-reduce helper."""
        return self.base_layer.maybe_all_reduce_tensor_model_parallel(*args, **kwargs)

    @property
    def _shared_experts(self):
        """The base layer's ``_shared_experts``."""
        return self.base_layer._shared_experts

    @property
    def quant_method(self):
        """The base layer's quant method (holding the LoRA-aware kernel)."""
        return self.base_layer.quant_method

    @property
    def is_internal_router(self) -> bool:
        """Whether the base layer uses an internal router."""
        return self.base_layer.is_internal_router