Unverified Commit 8cd174fa authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[LoRA] MoE LoRA Refactor (#40338)

parent c798593f
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -14,31 +13,17 @@ from vllm.distributed.parallel_state import ( ...@@ -14,31 +13,17 @@ from vllm.distributed.parallel_state import (
) )
from vllm.distributed.utils import divide from vllm.distributed.utils import divide
from vllm.lora.layers.base import BaseLayerWithLoRA from vllm.lora.layers.base import BaseLayerWithLoRA
from vllm.lora.ops.triton_ops.utils import get_lora_op_configs
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import (
_get_config_dtype_str,
)
from vllm.model_executor.layers.fused_moe.experts.gpt_oss_triton_kernels_moe import (
UnfusedOAITritonExperts,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
TritonExperts,
)
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import ( from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
FusedMoEModularMethod, FusedMoEModularMethod,
) )
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext
FusedMoEKernel, from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoDPEPModular, MoEPrepareAndFinalizeNoDPEPModular,
) )
from .utils import _get_lora_device, try_get_optimal_moe_lora_config from .utils import _get_lora_device
class FusedMoEWithLoRA(BaseLayerWithLoRA): class FusedMoEWithLoRA(BaseLayerWithLoRA):
...@@ -58,299 +43,49 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -58,299 +43,49 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
# For non-gated MoE (is_act_and_mul=False), only 1 slice is needed # For non-gated MoE (is_act_and_mul=False), only 1 slice is needed
# since there's only up_proj (w1), not gate_proj + up_proj (w1 + w3) # since there's only up_proj (w1), not gate_proj + up_proj (w1 + w3)
self._w13_slices = 2 if base_layer.moe_config.is_act_and_mul else 1 self._w13_slices = 2 if base_layer.moe_config.is_act_and_mul else 1
self._inject_lora_into_fused_moe()
def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]:
normalized_config = {}
for key, value in config.items():
if key.islower():
if key.startswith("block_"):
normalized_key = "BLOCK_SIZE_" + key.split("_")[-1].upper()
else:
normalized_key = key.upper()
else:
normalized_key = key
normalized_config[normalized_key] = value
return normalized_config
def _get_lora_moe_configs(
self,
op_prefix: str,
num_loras: int,
rank: int,
num_slices: int,
M: int,
layer: FusedMoE,
top_k: int,
config_dtype: str,
):
if envs.VLLM_TUNED_CONFIG_FOLDER:
hidden_size = layer.hidden_size
intermediate_size = (
self.w2_lora_a_stacked[0].shape[-1]
if op_prefix == "w2"
else self.w13_lora_b_stacked[0].shape[-2]
)
shrink_config = get_lora_op_configs(
op_type=f"fused_moe_lora_{op_prefix}_shrink",
max_loras=num_loras,
batch=M,
hidden_size=hidden_size,
rank=rank,
num_slices=num_slices,
moe_intermediate_size=intermediate_size,
)
expand_config = get_lora_op_configs(
op_type=f"fused_moe_lora_{op_prefix}_expand",
max_loras=num_loras,
batch=M,
hidden_size=hidden_size, # lora_a_stacked.shape[-1],
rank=rank,
num_slices=num_slices,
moe_intermediate_size=intermediate_size, # lora_b_stacked.shape[-2],
)
else: # fall back to the default config
get_config_func = functools.partial(
try_get_optimal_moe_lora_config,
w1_shape=layer.w13_weight.shape,
w2_shape=layer.w2_weight.shape,
rank=rank,
top_k=top_k,
dtype=config_dtype,
M=M,
block_shape=layer.quant_method.moe_quant_config.block_shape,
)
shrink_config = get_config_func(
op_type=f"fused_moe_lora_{op_prefix}_shrink"
)
expand_config = get_config_func(
op_type=f"fused_moe_lora_{op_prefix}_expand"
)
shrink_config = self._normalize_keys(shrink_config)
expand_config = self._normalize_keys(expand_config)
return shrink_config, expand_config
def _inject_lora_into_fused_moe(self):
moe_state_dict = {}
top_k = self.base_layer.top_k
self.base_layer.ensure_moe_quant_config_init() self.base_layer.ensure_moe_quant_config_init()
quant_config = self.base_layer.quant_method.moe_quant_config
if getattr(self.base_layer.quant_method, "supports_internal_mk", False): if getattr(self.base_layer.quant_method, "supports_internal_mk", False):
# Use the existing modular kernel from the quant method moe_kernel = self.base_layer.quant_method.moe_kernel
m_fused_moe_fn = self.base_layer.quant_method.moe_kernel
# Don't let the kernel own shared experts so the runner can # Don't let the kernel own shared experts so the runner can
# overlap them with routed experts via a separate CUDA stream. # overlap them with routed experts via a separate CUDA stream.
m_fused_moe_fn.shared_experts = None moe_kernel.shared_experts = None
else: else:
# Create a new modular kernel via select_gemm_impl.
# Don't pass shared_experts to the kernel so the runner can
# overlap them with routed experts via a separate CUDA stream.
prepare_finalize = MoEPrepareAndFinalizeNoDPEPModular() prepare_finalize = MoEPrepareAndFinalizeNoDPEPModular()
m_fused_moe_fn = FusedMoEKernel( moe_kernel = FusedMoEKernel(
prepare_finalize, prepare_finalize,
self.base_layer.quant_method.select_gemm_impl( self.base_layer.quant_method.select_gemm_impl(
prepare_finalize, self.base_layer prepare_finalize, self.base_layer
), ),
) )
assert moe_kernel.supports_lora(), (
if quant_config.use_mxfp4_w4a16: f"{type(moe_kernel.fused_experts).__name__} does not support LoRA. "
assert isinstance( "For unquantized MoE, set moe_backend='triton' or moe_backend='auto' "
m_fused_moe_fn.impl.fused_experts, "(auto selects Triton automatically when LoRA is enabled). "
(MarlinExperts, UnfusedOAITritonExperts), "For quantized MoE, mix LoRAExpertsMixin into the experts class "
) "and consume self._lora_context in apply()."
else:
assert isinstance(m_fused_moe_fn.impl.fused_experts, TritonExperts)
def fwd_decorator(layer, func):
def wrapper(*args, **kwargs):
moe_state_dict["hidden_states"] = kwargs["hidden_states"]
moe_state_dict["topk_ids"] = kwargs["topk_ids"]
moe_state_dict["topk_weights"] = kwargs["topk_weights"]
moe_state_dict["expert_map"] = kwargs["expert_map"]
moe_state_dict["apply_router_weight_on_input"] = kwargs[
"apply_router_weight_on_input"
]
result = func(*args, **kwargs)
return result
return wrapper
def act_decorator(layer, func):
def wrapper(*args, **kwargs):
_, output, input = args
hidden_states = moe_state_dict["hidden_states"]
topk_weights = moe_state_dict["topk_weights"]
curr_topk_ids = moe_state_dict["topk_ids"]
expert_map = moe_state_dict["expert_map"]
config_dtype = _get_config_dtype_str(
dtype=hidden_states.dtype,
use_fp8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
)
num_tokens = hidden_states.size(0)
M = num_tokens
max_lora_rank = self.w13_lora_a_stacked[0].shape[-2]
shrink_config, expand_config = self._get_lora_moe_configs(
op_prefix="w13",
num_loras=self.max_loras,
rank=max_lora_rank,
num_slices=self._w13_slices,
M=M,
layer=layer,
top_k=top_k,
config_dtype=config_dtype,
)
# SPARSITY_FACTOR is a heuristic margin ensuring tokens * top_k
# activates only a small fraction of total experts * loras.
SPARSITY_FACTOR = 8
naive_block_assignment = (
expert_map is None
and num_tokens * top_k * SPARSITY_FACTOR
<= self.base_layer.local_num_experts * self.max_loras
)
# get the block size of m from customized config or default config
(
token_lora_mapping,
sorted_token_ids_lora,
expert_ids_lora,
num_tokens_post_padded_lora,
) = self.punica_wrapper.moe_lora_align_block_size(
curr_topk_ids,
num_tokens,
shrink_config["BLOCK_SIZE_M"],
self.base_layer.local_num_experts,
self.max_loras,
self.adapter_enabled,
expert_map,
naive_block_assignment=naive_block_assignment,
)
moe_state_dict["sorted_token_ids_lora"] = sorted_token_ids_lora
moe_state_dict["expert_ids_lora"] = expert_ids_lora
moe_state_dict["num_tokens_post_padded_lora"] = (
num_tokens_post_padded_lora
)
moe_state_dict["token_lora_mapping"] = token_lora_mapping
if sorted_token_ids_lora is not None:
expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
sorted_token_ids_lora = sorted_token_ids_lora.view(
self.max_loras, -1
)
#
self.punica_wrapper.add_lora_fused_moe(
input.view(-1, top_k, input.shape[-1]),
hidden_states,
self.w13_lora_a_stacked,
self.w13_lora_b_stacked,
topk_weights,
sorted_token_ids_lora,
expert_ids_lora,
num_tokens_post_padded_lora,
max_lora_rank,
top_k,
shrink_config, ## pass the shrink config
expand_config, ## pass the expand config
self.adapter_enabled,
fully_sharded=self.fully_sharded,
token_lora_mapping=token_lora_mapping,
)
result = func(*args, **kwargs)
moe_state_dict["intermediate_cache2"] = output
return result
return wrapper
def moe_sum_decorator(layer, func):
def wrapper(*args, **kwargs):
hidden_states = moe_state_dict["hidden_states"]
topk_weights = moe_state_dict["topk_weights"]
config_dtype = _get_config_dtype_str(
dtype=hidden_states.dtype,
use_fp8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
)
num_tokens = hidden_states.size(0)
M = num_tokens
max_lora_rank = self.w2_lora_a_stacked[0].shape[-2]
shrink_config, expand_config = self._get_lora_moe_configs(
op_prefix="w2",
num_loras=self.max_loras,
rank=max_lora_rank,
num_slices=1,
M=M,
layer=layer,
top_k=top_k,
config_dtype=config_dtype,
)
sorted_token_ids_lora = moe_state_dict["sorted_token_ids_lora"]
expert_ids_lora = moe_state_dict["expert_ids_lora"]
num_tokens_post_padded_lora = moe_state_dict[
"num_tokens_post_padded_lora"
]
token_lora_mapping = moe_state_dict.get("token_lora_mapping")
if sorted_token_ids_lora is not None:
expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
sorted_token_ids_lora = sorted_token_ids_lora.view(
self.max_loras, -1
)
intermediate_cache2 = moe_state_dict["intermediate_cache2"]
intermediate_cache3 = args[0]
shard_size_w2 = divide(self.base_layer.hidden_size, self.tp_size)
self.punica_wrapper.add_lora_fused_moe(
intermediate_cache3,
intermediate_cache2,
self.w2_lora_a_stacked,
self.w2_lora_b_stacked,
topk_weights,
sorted_token_ids_lora,
expert_ids_lora,
num_tokens_post_padded_lora,
max_lora_rank,
top_k,
shrink_config, ## pass the shrink config
expand_config, ## pass the expand config
self.adapter_enabled,
True,
fully_sharded=self.fully_sharded,
offset=shard_size_w2 * self.tp_rank if self.fully_sharded else 0,
token_lora_mapping=token_lora_mapping,
)
result = func(*args, **kwargs)
return result
return wrapper
fused_experts = m_fused_moe_fn.impl.fused_experts
m_fused_moe_fn.apply = fwd_decorator(self.base_layer, m_fused_moe_fn.apply)
fused_experts.activation = act_decorator(
self.base_layer, fused_experts.activation
) )
fused_experts.moe_sum = moe_sum_decorator( self._fused_experts = moe_kernel.fused_experts
self.base_layer, fused_experts.moe_sum
)
# TODO(bnell): find a less intrusive way to handle this.
self.base_layer._replace_quant_method( self.base_layer._replace_quant_method(
FusedMoEModularMethod(self.base_layer.quant_method, m_fused_moe_fn) FusedMoEModularMethod(self.base_layer.quant_method, moe_kernel)
)
def _build_lora_context(self):
return MoELoRAContext(
w13_lora_a_stacked=self.w13_lora_a_stacked,
w13_lora_b_stacked=self.w13_lora_b_stacked,
w2_lora_a_stacked=self.w2_lora_a_stacked,
w2_lora_b_stacked=self.w2_lora_b_stacked,
adapter_enabled=self.adapter_enabled,
max_loras=self.max_loras,
top_k=self.base_layer.top_k,
w13_num_slices=self._w13_slices,
fully_sharded=self.fully_sharded,
tp_rank=self.tp_rank,
tp_size=self.tp_size,
local_num_experts=self.base_layer.local_num_experts,
punica_wrapper=self.punica_wrapper,
use_tuned_config=bool(envs.VLLM_TUNED_CONFIG_FOLDER),
) )
def _create_lora_a_weights( def _create_lora_a_weights(
...@@ -589,6 +324,10 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -589,6 +324,10 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2] index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
].copy_(sliced_w2_lora_b, non_blocking=True) ].copy_(sliced_w2_lora_b, non_blocking=True)
def set_mapping(self, punica_wrapper):
super().set_mapping(punica_wrapper)
self._fused_experts.set_lora_context(self._build_lora_context())
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
return self.base_layer.forward(*args, **kwargs) return self.base_layer.forward(*args, **kwargs)
......
...@@ -90,11 +90,12 @@ def try_get_optimal_moe_lora_config( ...@@ -90,11 +90,12 @@ def try_get_optimal_moe_lora_config(
top_k: int, top_k: int,
dtype: str | None, dtype: str | None,
M: int, M: int,
block_shape: list[int] | None = None,
) -> dict[str, int | None]: ) -> dict[str, int | None]:
config = try_get_optimal_moe_config( # LoRA shrink/expand operates on bf16/fp16 adapters regardless of the
w1_shape, w2_shape, top_k, dtype, M, block_shape # base MoE weight's block-wise quantization, so block_shape is omitted
).copy() # from the config lookup — the non-quantized branch in get_default_config
# ignores it anyway.
config = try_get_optimal_moe_config(w1_shape, w2_shape, top_k, dtype, M).copy()
if op_type in [ if op_type in [
"fused_moe_lora_w13_shrink", "fused_moe_lora_w13_shrink",
"fused_moe_lora_w2_shrink", "fused_moe_lora_w2_shrink",
......
...@@ -321,3 +321,20 @@ def supports_pdl(device: torch.device | None = None) -> bool: ...@@ -321,3 +321,20 @@ def supports_pdl(device: torch.device | None = None) -> bool:
def supports_tma(device: torch.device | None = None) -> bool: def supports_tma(device: torch.device | None = None) -> bool:
# TMA requires compute capability SM90 or above # TMA requires compute capability SM90 or above
return current_platform.is_cuda() and current_platform.has_device_capability(90) return current_platform.is_cuda() and current_platform.has_device_capability(90)
def _normalize_lora_config_keys(
config: dict[str, int | None],
) -> dict[str, int | None]:
"""Normalize Triton config dict keys to uppercase BLOCK_SIZE_* format."""
out: dict[str, int | None] = {}
for key, val in config.items():
if key.islower():
if key.startswith("block_"):
nk = "BLOCK_SIZE_" + key.split("_")[-1].upper()
else:
nk = key.upper()
else:
nk = key
out[nk] = val
return out
...@@ -493,3 +493,65 @@ class PunicaWrapperBase(PunicaWrapperABC): ...@@ -493,3 +493,65 @@ class PunicaWrapperBase(PunicaWrapperABC):
""" """
# TODO: implement it based on torch ops # TODO: implement it based on torch ops
raise NotImplementedError raise NotImplementedError
def add_lora_w13(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
expert_map: torch.Tensor | None,
w1: torch.Tensor,
w2: torch.Tensor,
num_tokens: int,
top_k_num: int,
max_loras: int,
adapter_enabled: torch.Tensor,
local_num_experts: int,
top_k: int,
num_slices: int,
fully_sharded: bool,
use_tuned_config: bool,
) -> tuple[
torch.Tensor | None,
torch.Tensor | None,
torch.Tensor | None,
torch.Tensor | None,
]:
"""Apply w13 LoRA to y (intermediate_cache1) in-place before activation.
Returns (sorted_token_ids_lora, expert_ids_lora,
num_tokens_post_padded_lora, token_lora_mapping)
for reuse by add_lora_w2.
"""
raise NotImplementedError
def add_lora_w2(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
topk_weights: torch.Tensor,
sorted_token_ids_lora: torch.Tensor | None,
expert_ids_lora: torch.Tensor | None,
num_tokens_post_padded_lora: torch.Tensor | None,
token_lora_mapping: torch.Tensor | None,
num_tokens: int,
w1: torch.Tensor,
w2: torch.Tensor,
top_k_num: int,
max_loras: int,
adapter_enabled: torch.Tensor,
top_k: int,
fully_sharded: bool,
tp_rank: int,
use_tuned_config: bool,
) -> None:
"""Apply w2 LoRA to y (intermediate_cache3) in-place before moe_sum.
Reuses routing tensors returned by add_lora_w13.
"""
raise NotImplementedError
...@@ -459,3 +459,239 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -459,3 +459,239 @@ class PunicaWrapperGPU(PunicaWrapperBase):
fully_sharded, fully_sharded,
offset, offset,
) )
def add_lora_w13(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
expert_map: torch.Tensor | None,
w1: torch.Tensor,
w2: torch.Tensor,
num_tokens: int,
top_k_num: int,
max_loras: int,
adapter_enabled: torch.Tensor,
local_num_experts: int,
top_k: int,
num_slices: int,
fully_sharded: bool,
use_tuned_config: bool,
) -> tuple[
torch.Tensor | None,
torch.Tensor | None,
torch.Tensor | None,
torch.Tensor | None,
]:
import functools
from vllm.lora.layers.utils import try_get_optimal_moe_lora_config
from vllm.lora.ops.triton_ops.utils import (
_normalize_lora_config_keys,
get_lora_op_configs,
)
from vllm.model_executor.layers.fused_moe.config import _get_config_dtype_str
config_dtype = _get_config_dtype_str(
dtype=x.dtype,
use_fp8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
)
max_lora_rank = lora_a_stacked[0].shape[-2]
if use_tuned_config:
shrink_config = get_lora_op_configs(
op_type="fused_moe_lora_w13_shrink",
max_loras=max_loras,
batch=num_tokens,
hidden_size=x.shape[-1],
rank=max_lora_rank,
num_slices=num_slices,
moe_intermediate_size=lora_b_stacked[0].shape[-2],
)
expand_config = get_lora_op_configs(
op_type="fused_moe_lora_w13_expand",
max_loras=max_loras,
batch=num_tokens,
hidden_size=x.shape[-1],
rank=max_lora_rank,
num_slices=num_slices,
moe_intermediate_size=lora_b_stacked[0].shape[-2],
)
else:
get_config = functools.partial(
try_get_optimal_moe_lora_config,
w1_shape=w1.shape,
w2_shape=w2.shape,
rank=max_lora_rank,
top_k=top_k,
dtype=config_dtype,
M=num_tokens,
)
shrink_config = get_config(op_type="fused_moe_lora_w13_shrink")
expand_config = get_config(op_type="fused_moe_lora_w13_expand")
shrink_config = _normalize_lora_config_keys(shrink_config)
expand_config = _normalize_lora_config_keys(expand_config)
SPARSITY_FACTOR = 8
naive_block_assignment = (
expert_map is None
and num_tokens * top_k * SPARSITY_FACTOR <= local_num_experts * max_loras
)
(
token_lora_mapping,
sorted_token_ids_lora,
expert_ids_lora,
num_tokens_post_padded_lora,
) = self.moe_lora_align_block_size(
topk_ids,
num_tokens,
int(shrink_config.get("BLOCK_SIZE_M") or 64),
local_num_experts,
max_loras,
adapter_enabled,
expert_map,
naive_block_assignment=naive_block_assignment,
)
_sorted = sorted_token_ids_lora
_eids = expert_ids_lora
if _sorted is not None:
_eids = _eids.view(max_loras, -1)
_sorted = _sorted.view(max_loras, -1)
self.add_lora_fused_moe(
y.view(-1, top_k_num, y.shape[-1]),
x,
lora_a_stacked,
lora_b_stacked,
topk_weights,
_sorted,
_eids,
num_tokens_post_padded_lora,
max_lora_rank,
top_k,
shrink_config,
expand_config,
adapter_enabled,
fully_sharded=fully_sharded,
token_lora_mapping=token_lora_mapping,
)
return (
sorted_token_ids_lora,
expert_ids_lora,
num_tokens_post_padded_lora,
token_lora_mapping,
)
def add_lora_w2(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
topk_weights: torch.Tensor,
sorted_token_ids_lora: torch.Tensor | None,
expert_ids_lora: torch.Tensor | None,
num_tokens_post_padded_lora: torch.Tensor | None,
token_lora_mapping: torch.Tensor | None,
num_tokens: int,
w1: torch.Tensor,
w2: torch.Tensor,
top_k_num: int,
max_loras: int,
adapter_enabled: torch.Tensor,
top_k: int,
fully_sharded: bool,
tp_rank: int,
use_tuned_config: bool,
) -> None:
import functools
from vllm.lora.layers.utils import try_get_optimal_moe_lora_config
from vllm.lora.ops.triton_ops.utils import (
_normalize_lora_config_keys,
get_lora_op_configs,
)
from vllm.model_executor.layers.fused_moe.config import _get_config_dtype_str
config_dtype = _get_config_dtype_str(
dtype=x.dtype,
use_fp8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
)
max_lora_rank = lora_a_stacked[0].shape[-2]
if use_tuned_config:
shrink_config = get_lora_op_configs(
op_type="fused_moe_lora_w2_shrink",
max_loras=max_loras,
batch=num_tokens,
hidden_size=y.shape[-1],
rank=max_lora_rank,
num_slices=1,
moe_intermediate_size=lora_a_stacked[0].shape[-1],
)
expand_config = get_lora_op_configs(
op_type="fused_moe_lora_w2_expand",
max_loras=max_loras,
batch=num_tokens,
hidden_size=y.shape[-1],
rank=max_lora_rank,
num_slices=1,
moe_intermediate_size=lora_a_stacked[0].shape[-1],
)
else:
get_config = functools.partial(
try_get_optimal_moe_lora_config,
w1_shape=w1.shape,
w2_shape=w2.shape,
rank=max_lora_rank,
top_k=top_k,
dtype=config_dtype,
M=num_tokens,
)
shrink_config = get_config(op_type="fused_moe_lora_w2_shrink")
expand_config = get_config(op_type="fused_moe_lora_w2_expand")
shrink_config = _normalize_lora_config_keys(shrink_config)
expand_config = _normalize_lora_config_keys(expand_config)
_sorted = sorted_token_ids_lora
_eids = expert_ids_lora
if _sorted is not None:
assert _eids is not None
_eids = _eids.view(max_loras, -1)
_sorted = _sorted.view(max_loras, -1)
# w2_lora_b shape[-2] is hidden_size // tp_size when fully_sharded
shard_size = lora_b_stacked[0].shape[-2]
offset = shard_size * tp_rank if fully_sharded else 0
self.add_lora_fused_moe(
y,
x,
lora_a_stacked,
lora_b_stacked,
topk_weights,
_sorted,
_eids,
num_tokens_post_padded_lora,
max_lora_rank,
top_k,
shrink_config,
expand_config,
adapter_enabled,
True, # mul_routed_weight
fully_sharded=fully_sharded,
offset=offset,
token_lora_mapping=token_lora_mapping,
)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
...@@ -16,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -16,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
RoutingMethodType, RoutingMethodType,
) )
from vllm.model_executor.layers.fused_moe.lora_experts_mixin import LoRAExpertsMixin
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP, TopKWeightAndReduceNoOP,
) )
...@@ -654,7 +654,7 @@ class OAITritonExperts(BaseOAITritonExperts): ...@@ -654,7 +654,7 @@ class OAITritonExperts(BaseOAITritonExperts):
) )
class UnfusedOAITritonExperts(BaseOAITritonExperts): class UnfusedOAITritonExperts(LoRAExpertsMixin, BaseOAITritonExperts):
""" """
A Triton based MoE expert class that operates on expert standard A Triton based MoE expert class that operates on expert standard
format and explicitly keeps the activation and reduction (moe_sum) steps format and explicitly keeps the activation and reduction (moe_sum) steps
...@@ -721,6 +721,7 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts): ...@@ -721,6 +721,7 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
if quant_config is None: if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
global_topk_ids = topk_ids
if expert_map is not None: if expert_map is not None:
topk_ids = expert_map[topk_ids] topk_ids = expert_map[topk_ids]
...@@ -775,10 +776,40 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts): ...@@ -775,10 +776,40 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
y=intermediate_cache1, y=intermediate_cache1,
) )
# w13 LoRA: gather the activation input from expert-sorted
# intermediate_cache1, then add the LoRA delta in-place on that copy
# before passing it to activation — exactly mirroring the old
# decorator approach which modified the gathered tensor in-place.
act_input = intermediate_cache1.view(-1, N)[gather_indx.dst_indx]
sorted_token_ids_lora = None
expert_ids_lora = None
num_tokens_post_padded_lora = None
token_lora_mapping = None
lora_context = self._lora_context
if lora_context is not None:
(
sorted_token_ids_lora,
expert_ids_lora,
num_tokens_post_padded_lora,
token_lora_mapping,
) = self.apply_w13_lora(
lora_context,
y=act_input,
x=hidden_states,
topk_ids=global_topk_ids,
topk_weights=topk_weights,
expert_map=expert_map,
w1=w1,
w2=w2,
num_tokens=M,
top_k_num=topk,
)
self.activation( self.activation(
activation, activation,
intermediate_cache2, intermediate_cache2,
intermediate_cache1.view(-1, N)[gather_indx.dst_indx], act_input,
) )
# matmul_ogs grouped reduction fuse sum across multiple experts: # matmul_ogs grouped reduction fuse sum across multiple experts:
...@@ -797,6 +828,24 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts): ...@@ -797,6 +828,24 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
y=intermediate_cache3, y=intermediate_cache3,
) )
# w2 LoRA: after matmul_ogs with scatter_indx, intermediate_cache3 is
# in token-topk order, matching the (M, topk, K) layout add_lora_w2 expects.
if lora_context is not None:
self.apply_w2_lora(
lora_context,
y=intermediate_cache3.view(-1, topk, K),
x=intermediate_cache2,
topk_weights=topk_weights,
sorted_token_ids_lora=sorted_token_ids_lora,
expert_ids_lora=expert_ids_lora,
num_tokens_post_padded_lora=num_tokens_post_padded_lora,
token_lora_mapping=token_lora_mapping,
num_tokens=M,
w1=w1,
w2=w2,
top_k_num=topk,
)
self.moe_sum(intermediate_cache3.view(-1, topk, K), output) self.moe_sum(intermediate_cache3.view(-1, topk, K), output)
......
...@@ -17,6 +17,7 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -17,6 +17,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEParallelConfig, FusedMoEParallelConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.lora_experts_mixin import LoRAExpertsMixin
from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
batched_moe_align_block_size, batched_moe_align_block_size,
moe_align_block_size, moe_align_block_size,
...@@ -655,7 +656,7 @@ class MarlinExpertsBase(mk.FusedMoEExpertsModular): ...@@ -655,7 +656,7 @@ class MarlinExpertsBase(mk.FusedMoEExpertsModular):
return E, M, N, K, topk return E, M, N, K, topk
class MarlinExperts(MarlinExpertsBase): class MarlinExperts(LoRAExpertsMixin, MarlinExpertsBase):
"""Marlin-based fused MoE expert implementation.""" """Marlin-based fused MoE expert implementation."""
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
...@@ -720,6 +721,9 @@ class MarlinExperts(MarlinExpertsBase): ...@@ -720,6 +721,9 @@ class MarlinExperts(MarlinExpertsBase):
): ):
assert self.w1_scale is not None assert self.w1_scale is not None
assert self.w2_scale is not None assert self.w2_scale is not None
ctx = self._lora_context
if ctx is None:
fused_marlin_moe( fused_marlin_moe(
hidden_states=hidden_states, hidden_states=hidden_states,
w1=w1, w1=w1,
...@@ -751,6 +755,102 @@ class MarlinExperts(MarlinExpertsBase): ...@@ -751,6 +755,102 @@ class MarlinExperts(MarlinExpertsBase):
is_k_full=self.is_k_full, is_k_full=self.is_k_full,
input_dtype=self.input_dtype, input_dtype=self.input_dtype,
) )
return
# LoRA path: wrap activation_func and moe_sum to inject LoRA at the
# two natural injection points.
#
# Marlin uses moe_align_block_size (same as TritonExperts) so
# intermediate_cache1 is indexed by flat (token, expert) pair index,
# which is compatible with add_lora_fused_moe's scatter mechanism.
M = hidden_states.size(0)
top_k_num = topk_ids.size(1)
lora_state: dict = {}
def activation_with_lora(
act_enum: MoEActivation,
act_output: torch.Tensor,
act_input: torch.Tensor,
) -> None:
# act_input = intermediate_cache1 (M*topk, 2N for gated)
# act_output = intermediate_cache2 (M*topk, N)
(
sorted_token_ids_lora,
expert_ids_lora,
num_tokens_post_padded_lora,
token_lora_mapping,
) = self.apply_w13_lora(
ctx,
y=act_input,
x=hidden_states,
topk_ids=topk_ids,
topk_weights=topk_weights,
expert_map=expert_map,
w1=w1,
w2=w2,
num_tokens=M,
top_k_num=top_k_num,
)
lora_state.update(
{
"sorted": sorted_token_ids_lora,
"eids": expert_ids_lora,
"npad": num_tokens_post_padded_lora,
"tlm": token_lora_mapping,
}
)
self.activation(act_enum, act_output, act_input)
lora_state["cache2"] = act_output
def moe_sum_with_lora(moe_out: torch.Tensor, out: torch.Tensor) -> None:
# moe_out shape: (M, topk, K)
self.apply_w2_lora(
ctx,
y=moe_out,
x=lora_state["cache2"],
topk_weights=topk_weights,
sorted_token_ids_lora=lora_state["sorted"],
expert_ids_lora=lora_state["eids"],
num_tokens_post_padded_lora=lora_state["npad"],
token_lora_mapping=lora_state["tlm"],
num_tokens=M,
w1=w1,
w2=w2,
top_k_num=top_k_num,
)
self.moe_sum(moe_out, out)
return fused_marlin_moe(
hidden_states=hidden_states,
w1=w1,
w2=w2,
bias1=self.w1_bias,
bias2=self.w2_bias,
w1_scale=self.w1_scale,
w2_scale=self.w2_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
global_scale1=self.g1_alphas,
global_scale2=self.g2_alphas,
quant_type_id=self.quant_type_id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
activation=activation,
activation_func=activation_with_lora,
moe_sum=moe_sum_with_lora,
expert_map=expert_map,
output=output,
intermediate_cache13=workspace2,
intermediate_cache2=workspace13,
g_idx1=self.w13_g_idx,
g_idx2=self.w2_g_idx,
sort_indices1=self.w13_g_idx_sort_indices,
sort_indices2=self.w2_g_idx_sort_indices,
is_k_full=self.is_k_full,
input_dtype=self.input_dtype,
)
def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None: def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
ops.moe_sum(input, output) ops.moe_sum(input, output)
......
...@@ -25,6 +25,7 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -25,6 +25,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
_get_config_dtype_str, _get_config_dtype_str,
) )
from vllm.model_executor.layers.fused_moe.lora_experts_mixin import LoRAExpertsMixin
from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size, moe_align_block_size,
) )
...@@ -1886,7 +1887,7 @@ def fused_experts_impl( ...@@ -1886,7 +1887,7 @@ def fused_experts_impl(
return out_hidden_states return out_hidden_states
class TritonExperts(mk.FusedMoEExpertsModular): class TritonExperts(LoRAExpertsMixin, mk.FusedMoEExpertsModular):
"""Triton-based fused MoE expert implementation.""" """Triton-based fused MoE expert implementation."""
def __init__( def __init__(
...@@ -2094,6 +2095,33 @@ class TritonExperts(mk.FusedMoEExpertsModular): ...@@ -2094,6 +2095,33 @@ class TritonExperts(mk.FusedMoEExpertsModular):
B_bias=self.w1_bias, B_bias=self.w1_bias,
) )
# LoRA w13: applied to intermediate_cache1 before activation, using
# hidden_states as the lora_a input. moe_lora_align_block_size is
# called once here and results reused for the w2 LoRA below.
sorted_token_ids_lora = None
expert_ids_lora = None
num_tokens_post_padded_lora = None
token_lora_mapping = None
lora_context = self._lora_context
if lora_context is not None:
(
sorted_token_ids_lora,
expert_ids_lora,
num_tokens_post_padded_lora,
token_lora_mapping,
) = self.apply_w13_lora(
lora_context,
y=intermediate_cache1,
x=hidden_states,
topk_ids=topk_ids,
topk_weights=topk_weights,
expert_map=expert_map,
w1=w1,
w2=w2,
num_tokens=num_tokens,
top_k_num=top_k_num,
)
self.activation( self.activation(
activation, intermediate_cache2, intermediate_cache1.view(-1, N) activation, intermediate_cache2, intermediate_cache1.view(-1, N)
) )
...@@ -2132,6 +2160,25 @@ class TritonExperts(mk.FusedMoEExpertsModular): ...@@ -2132,6 +2160,25 @@ class TritonExperts(mk.FusedMoEExpertsModular):
B_bias=self.w2_bias, B_bias=self.w2_bias,
) )
# LoRA w2: applied to intermediate_cache3 before moe_sum, using the
# unquantized intermediate_cache2 as the lora_a input. Reuses the
# sorted_token_ids_lora computed above.
if lora_context is not None:
self.apply_w2_lora(
lora_context,
y=intermediate_cache3,
x=intermediate_cache2,
topk_weights=topk_weights,
sorted_token_ids_lora=sorted_token_ids_lora,
expert_ids_lora=expert_ids_lora,
num_tokens_post_padded_lora=num_tokens_post_padded_lora,
token_lora_mapping=token_lora_mapping,
num_tokens=num_tokens,
w1=w1,
w2=w2,
top_k_num=top_k_num,
)
# separate function is required for MoE + LoRA # separate function is required for MoE + LoRA
self.moe_sum(intermediate_cache3, output) self.moe_sum(intermediate_cache3, output)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
@dataclass
class MoELoRAContext:
"""
Carries all LoRA state for one MoE forward pass.
Built by FusedMoEWithLoRA.forward() and propagated explicitly through the
modular kernel path (FusedMoEKernel -> FusedMoEExpertsModular.apply) so
that TritonExperts.apply() can compute the LoRA contribution inline,
replacing the decorator-based monkey-patch approach.
"""
# LoRA weight tensors (same shapes as FusedMoEWithLoRA attributes)
w13_lora_a_stacked: tuple[torch.Tensor, ...]
w13_lora_b_stacked: tuple[torch.Tensor, ...]
w2_lora_a_stacked: tuple[torch.Tensor, ...]
w2_lora_b_stacked: tuple[torch.Tensor, ...]
# (max_loras + 1,) int32; slot 0 is the "no-adapter" sentinel
adapter_enabled: torch.Tensor
# Metadata
max_loras: int
top_k: int
w13_num_slices: int # 2 = gated (gate + up), 1 = non-gated or 3D-fused
fully_sharded: bool
tp_rank: int
tp_size: int
local_num_experts: int
punica_wrapper: PunicaWrapperBase
# Whether VLLM_TUNED_CONFIG_FOLDER is set; selects get_lora_op_configs vs
# try_get_optimal_moe_lora_config for Triton kernel tile configs.
use_tuned_config: bool
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.model_executor.layers.fused_moe.lora_context import MoELoRAContext
class LoRAExpertsMixin:
"""
Mixin for FusedMoEExpertsModular subclasses that natively handle
MoELoRAContext inside their apply() implementation.
Mixing this class in:
- Flips supports_lora() to True so _can_fused_experts_support lets
LoRA through the gate check.
- Stashes a MoELoRAContext on the experts instance via
set_lora_context(), which apply() consumes from self._lora_context.
- Provides apply_w13_lora / apply_w2_lora helpers that dispatch to
the PunicaWrapper kernels.
The helper methods are pure functions of their inputs; all required
state is on lora_context or passed as arguments.
"""
_lora_context: MoELoRAContext | None = None
def set_lora_context(self, ctx: MoELoRAContext) -> None:
self._lora_context = ctx
@staticmethod
def supports_lora() -> bool:
return True
def apply_w13_lora(
self,
lora_context: MoELoRAContext,
*,
y: torch.Tensor,
x: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
expert_map: torch.Tensor | None,
w1: torch.Tensor,
w2: torch.Tensor,
num_tokens: int,
top_k_num: int,
) -> tuple[
torch.Tensor | None,
torch.Tensor | None,
torch.Tensor | None,
torch.Tensor | None,
]:
return lora_context.punica_wrapper.add_lora_w13(
y,
x,
lora_context.w13_lora_a_stacked,
lora_context.w13_lora_b_stacked,
topk_ids,
topk_weights,
expert_map,
w1,
w2,
num_tokens,
top_k_num,
lora_context.max_loras,
lora_context.adapter_enabled,
lora_context.local_num_experts,
lora_context.top_k,
lora_context.w13_num_slices,
lora_context.fully_sharded,
lora_context.use_tuned_config,
)
def apply_w2_lora(
self,
lora_context: MoELoRAContext,
*,
y: torch.Tensor,
x: torch.Tensor,
topk_weights: torch.Tensor,
sorted_token_ids_lora: torch.Tensor | None,
expert_ids_lora: torch.Tensor | None,
num_tokens_post_padded_lora: torch.Tensor | None,
token_lora_mapping: torch.Tensor | None,
num_tokens: int,
w1: torch.Tensor,
w2: torch.Tensor,
top_k_num: int,
) -> None:
lora_context.punica_wrapper.add_lora_w2(
y,
x,
lora_context.w2_lora_a_stacked,
lora_context.w2_lora_b_stacked,
topk_weights,
sorted_token_ids_lora,
expert_ids_lora,
num_tokens_post_padded_lora,
token_lora_mapping,
num_tokens,
w1,
w2,
top_k_num,
lora_context.max_loras,
lora_context.adapter_enabled,
lora_context.top_k,
lora_context.fully_sharded,
lora_context.tp_rank,
lora_context.use_tuned_config,
)
...@@ -570,6 +570,8 @@ class FusedMoEExperts(ABC): ...@@ -570,6 +570,8 @@ class FusedMoEExperts(ABC):
return False, _make_reason(f"{activation_format.value} activation format") return False, _make_reason(f"{activation_format.value} activation format")
elif envs.VLLM_BATCH_INVARIANT and not cls._supports_batch_invariance(): elif envs.VLLM_BATCH_INVARIANT and not cls._supports_batch_invariance():
return False, _make_reason("batch invariance") return False, _make_reason("batch invariance")
elif moe_config.is_lora_enabled and not cls.supports_lora():
return False, _make_reason("LoRA")
return True, None return True, None
@staticmethod @staticmethod
...@@ -734,6 +736,15 @@ class FusedMoEExperts(ABC): ...@@ -734,6 +736,15 @@ class FusedMoEExperts(ABC):
def g2_alphas(self) -> torch.Tensor | None: def g2_alphas(self) -> torch.Tensor | None:
return self.quant_config.g2_alphas return self.quant_config.g2_alphas
@staticmethod
def supports_lora() -> bool:
"""Return True if this expert impl natively handles LoRA.
LoRA-aware experts should mix in LoRAExpertsMixin, which flips this
to True and provides the per-forward LoRA state plumbing.
"""
return False
@abstractmethod @abstractmethod
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
""" """
...@@ -1527,6 +1538,9 @@ class FusedMoEKernel: ...@@ -1527,6 +1538,9 @@ class FusedMoEKernel:
def fused_experts(self) -> FusedMoEExperts: def fused_experts(self) -> FusedMoEExperts:
return self.impl.fused_experts return self.impl.fused_experts
def supports_lora(self) -> bool:
return self.fused_experts.supports_lora()
def _post_init_setup(self): def _post_init_setup(self):
""" """
Resolve any leftover setup dependencies between self.prepare_finalize Resolve any leftover setup dependencies between self.prepare_finalize
......
...@@ -220,9 +220,6 @@ def select_fp8_moe_backend( ...@@ -220,9 +220,6 @@ def select_fp8_moe_backend(
Note: Shape-specific fallbacks may still occur at runtime. Note: Shape-specific fallbacks may still occur at runtime.
""" """
if config.is_lora_enabled:
return Fp8MoeBackend.TRITON, backend_to_kernel_cls(Fp8MoeBackend.TRITON)[0]
# NOTE: the kernels are selected in the following order. # NOTE: the kernels are selected in the following order.
AVAILABLE_BACKENDS = _get_priority_backends(config, weight_key, activation_key) AVAILABLE_BACKENDS = _get_priority_backends(config, weight_key, activation_key)
......
...@@ -214,19 +214,6 @@ def select_unquantized_moe_backend( ...@@ -214,19 +214,6 @@ def select_unquantized_moe_backend(
return backend, k_cls return backend, k_cls
raise ValueError(_make_log_unsupported(backend, reason)) raise ValueError(_make_log_unsupported(backend, reason))
# LoRA needs Triton's unfused activation/reduction hooks. Selecting the
# backend here ensures weights stay in a LoRA-compatible layout instead of
# being permuted for a backend like FlashInfer or AITER during load.
if moe_config.is_lora_enabled:
backend = UnquantizedMoeBackend.TRITON
if activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
backend = UnquantizedMoeBackend.BATCHED_TRITON
return _return_or_raise(
backend,
moe_config,
activation_format,
)
runner_backend = moe_config.moe_backend runner_backend = moe_config.moe_backend
if runner_backend != "auto": if runner_backend != "auto":
requested_backend = map_unquantized_backend(runner_backend) requested_backend = map_unquantized_backend(runner_backend)
......
...@@ -297,7 +297,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -297,7 +297,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor: ) -> torch.Tensor:
return self.forward_native( return self.forward_native(
layer, x, topk_weights, topk_ids, shared_experts_input layer,
x,
topk_weights,
topk_ids,
shared_experts_input,
) )
def apply_monolithic( def apply_monolithic(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment