Unverified Commit 5f6cbf60 authored by Chen Wu's avatar Chen Wu Committed by GitHub
Browse files

[Feature][Kernel]FusedMoE LoRA (#21229)


Signed-off-by: default avatarwuchen <cntryroa@gmail.com>
Signed-off-by: default avatarbanjuede <lmklhc@163.com>
Signed-off-by: default avatarChen Wu <cntryroa@gmail.com>
Signed-off-by: default avatarDanielle Robinson <dmmaddix@amazon.com>
Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
Signed-off-by: default avatarbk-201 <joy25810@foxmail.com>
Co-authored-by: default avatarwuchen <wuchen@zetyun.com>
Co-authored-by: default avatarNathan Van Gheem <vangheem@gmail.com>
Co-authored-by: default avatarbanjuede <lmklhc@163.com>
Co-authored-by: default avatarDanielle Robinson <dmmaddix@amazon.com>
Co-authored-by: default avatarJee Jee Li <pandaleefree@gmail.com>
Co-authored-by: default avatarbk-201 <joy25810@foxmail.com>
parent 3ada34f9
......@@ -23,6 +23,7 @@ from vllm.lora.layers import (
BaseLayerWithLoRA,
ColumnParallelLinearWithLoRA,
ColumnParallelLinearWithShardedLoRA,
FusedMoEWithLoRA,
LogitsProcessorWithLoRA,
MergedColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithShardedLoRA,
......@@ -35,7 +36,9 @@ from vllm.lora.layers import (
RowParallelLinearWithShardedLoRA,
VocabParallelEmbeddingWithLoRA,
)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.utils import get_moe_expert_mapping, get_packed_modules_mapping
if TYPE_CHECKING:
from vllm.model_executor.layers.logits_processor import LogitsProcessor
......@@ -58,9 +61,18 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = {
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLoRA,
RowParallelLinearWithShardedLoRA,
FusedMoEWithLoRA,
}
def is_moe_model(model: nn.Module) -> bool:
"""Checks if the model contains FusedMoE layers and warns the user."""
if any(isinstance(module, FusedMoE) for module in model.modules()):
logger.info_once("MoE model detected. Using fused MoE LoRA implementation.")
return True
return False
def from_layer(
layer: nn.Module,
max_loras: int,
......@@ -205,6 +217,9 @@ def get_supported_lora_modules(model: nn.Module) -> list[str]:
if isinstance(module, (LinearBase,)):
supported_lora_modules.add(name.split(".")[-1])
if isinstance(module, (FusedMoE,)):
supported_lora_modules.add(name.split(".")[-1])
return list(supported_lora_modules)
......@@ -252,3 +267,27 @@ def get_adapter_absolute_path(lora_path: str) -> str:
return lora_path
return local_snapshot_path
def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]:
if is_moe_model(model):
if moe_packed_mapping := get_moe_expert_mapping(model):
# This method generates and returns a dictionary mapping packed module
# names to lists of their corresponding submodule names. It includes
# both static mappings and dynamic mappings for expert layers, where
# the expert indices are expanded based on the configured number
# of routed experts.
packed_modules_mapping = get_packed_modules_mapping(model)
packed_modules_mapping["experts"] = [
weight_name.rstrip(".") for _, weight_name, _, _ in moe_packed_mapping
]
return packed_modules_mapping
else:
raise AttributeError(
"To support LoRA for MoE model, "
"'get_expert_mapping' must be implemented"
)
else:
return get_packed_modules_mapping(model)
......@@ -94,7 +94,8 @@ class WorkerLoRAManager:
expected_lora_modules.extend(packed_modules_mapping[module])
else:
expected_lora_modules.append(module)
if module == "experts":
expected_lora_modules.append(module)
expected_lora_modules = list(set(expected_lora_modules))
lora_path = get_adapter_absolute_path(lora_request.lora_path)
......
......@@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused MoE utilities for GPTQ."""
from collections.abc import Callable
import torch
import vllm._custom_ops as ops
......@@ -11,6 +13,9 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
batched_moe_align_block_size,
moe_align_block_size,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
TopKWeightAndReduceNoOP,
......@@ -24,6 +29,21 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
from vllm.scalar_type import ScalarType, scalar_types
def default_activation_func(
activation: str, output: torch.Tensor, input: torch.Tensor
) -> None:
if activation == "silu":
torch.ops._C.silu_and_mul(output, input)
elif activation == "swigluoai":
# alpha = 1.702, limit = 7.0
torch.ops._C.swigluoai_and_mul(output, input)
else:
raise ValueError(
f"Unsupported activation: {activation}. "
"Only silu and swigluoai activations are supported."
)
def _fused_marlin_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
......@@ -36,12 +56,15 @@ def _fused_marlin_moe(
num_topk: int,
quant_type: ScalarType,
apply_router_weight_on_input: bool,
activation: str,
expert_map: torch.Tensor | None,
block_size_m: int,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
activation: str = "silu",
activation_func: Callable[
[str, torch.Tensor, torch.Tensor], None
] = default_activation_func,
global_scale1: torch.Tensor | None = None,
global_scale2: torch.Tensor | None = None,
g_idx1: torch.Tensor | None = None,
......@@ -118,20 +141,9 @@ def _fused_marlin_moe(
is_zp_float=False,
)
if activation == "silu":
torch.ops._C.silu_and_mul(
intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
)
elif activation == "swigluoai":
# alpha = 1.702, limit = 7.0
torch.ops._C.swigluoai_and_mul(
intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
)
else:
raise ValueError(
f"Unsupported activation: {activation}. "
"Only silu and swigluoai activations are supported."
)
activation_func(
activation, intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
)
if output is None:
output = intermediate_cache3
......@@ -185,7 +197,11 @@ def fused_marlin_moe(
quant_type_id: int,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
activation: str | None = "silu",
activation: str = "silu",
activation_func: Callable[
[str, torch.Tensor, torch.Tensor], None
] = default_activation_func,
moe_sum: Callable[[torch.Tensor, torch.Tensor], None] | None = None,
expert_map: torch.Tensor | None = None,
global_scale1: torch.Tensor | None = None,
global_scale2: torch.Tensor | None = None,
......@@ -290,12 +306,13 @@ def fused_marlin_moe(
num_topk=topk,
quant_type=quant_type,
apply_router_weight_on_input=apply_router_weight_on_input,
activation=activation,
expert_map=expert_map,
block_size_m=block_size_m,
sorted_token_ids=sorted_token_ids,
expert_ids=expert_ids,
num_tokens_post_padded=num_tokens_post_padded,
activation=activation,
activation_func=activation_func,
global_scale1=global_scale1,
global_scale2=global_scale2,
g_idx1=g_idx1,
......@@ -317,7 +334,10 @@ def fused_marlin_moe(
else:
output = torch.empty_like(hidden_states)
return torch.sum(moe_output.view(-1, topk, K), dim=1, out=output)
if moe_sum is None:
return torch.sum(moe_output.view(-1, topk, K), dim=1, out=output)
else:
return moe_sum(moe_output, output)
def batched_fused_marlin_moe(
......@@ -600,6 +620,8 @@ class MarlinExperts(MarlinExpertsBase):
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
activation=activation,
activation_func=self.activation,
moe_sum=self.moe_sum,
expert_map=expert_map,
output=output,
# Workspaces are swapped in workspace_shapes() to account for proper
......@@ -608,6 +630,19 @@ class MarlinExperts(MarlinExpertsBase):
intermediate_cache2=workspace13,
)
def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
ops.moe_sum(input, output)
def modular_marlin_fused_moe(
quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None
) -> mk.FusedMoEModularKernel:
return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
MarlinExperts(quant_config),
shared_experts,
)
class BatchedMarlinExperts(MarlinExpertsBase):
def __init__(
......
......@@ -2135,13 +2135,18 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
B_bias=self.w2_bias,
)
ops.moe_sum(intermediate_cache3, output)
# separate function is required for MoE + LoRA
self.moe_sum(intermediate_cache3, output)
def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
ops.moe_sum(input, output)
def modular_triton_fused_moe(
quant_config: FusedMoEQuantConfig,
quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None
) -> mk.FusedMoEModularKernel:
return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
TritonExperts(quant_config),
shared_experts,
)
......@@ -557,6 +557,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
torch.ops._C.silu_and_mul(output, input)
elif activation == "gelu":
torch.ops._C.gelu_and_mul(output, input)
elif activation == "swigluoai":
# alpha = 1.702, limit = 7.0
torch.ops._C.swigluoai_and_mul(output, input)
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
......
......@@ -1313,6 +1313,17 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return SharedFusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts,
num_redundant_experts=0,
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
......
......@@ -32,7 +32,7 @@ from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv
from .interfaces import SupportsEagle3, SupportsPP
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
from .utils import (
AutoWeightsLoader,
WeightsMapper,
......@@ -627,7 +627,7 @@ class GptOssModel(nn.Module):
)
class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3):
class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}
hf_to_vllm_mapper = WeightsMapper(
......@@ -696,6 +696,17 @@ class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3):
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, weight scales, activation scales
# (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_local_experts,
num_redundant_experts=0,
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
......
......@@ -49,7 +49,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (
AutoWeightsLoader,
is_pp_missing_parameter,
......@@ -349,8 +349,6 @@ class OlmoeModel(nn.Module):
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
......@@ -433,17 +431,13 @@ class OlmoeModel(nn.Module):
return loaded_params
class OlmoeForCausalLM(nn.Module, SupportsPP):
class OlmoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
]
}
def __init__(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment