Unverified Commit 5f6cbf60 authored by Chen Wu, committed by GitHub
Browse files

[Feature][Kernel]FusedMoE LoRA (#21229)


Signed-off-by: wuchen <cntryroa@gmail.com>
Signed-off-by: banjuede <lmklhc@163.com>
Signed-off-by: Chen Wu <cntryroa@gmail.com>
Signed-off-by: Danielle Robinson <dmmaddix@amazon.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: bk-201 <joy25810@foxmail.com>
Co-authored-by: wuchen <wuchen@zetyun.com>
Co-authored-by: Nathan Van Gheem <vangheem@gmail.com>
Co-authored-by: banjuede <lmklhc@163.com>
Co-authored-by: Danielle Robinson <dmmaddix@amazon.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: bk-201 <joy25810@foxmail.com>
parent 3ada34f9
...@@ -23,6 +23,7 @@ from vllm.lora.layers import ( ...@@ -23,6 +23,7 @@ from vllm.lora.layers import (
BaseLayerWithLoRA, BaseLayerWithLoRA,
ColumnParallelLinearWithLoRA, ColumnParallelLinearWithLoRA,
ColumnParallelLinearWithShardedLoRA, ColumnParallelLinearWithShardedLoRA,
FusedMoEWithLoRA,
LogitsProcessorWithLoRA, LogitsProcessorWithLoRA,
MergedColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithShardedLoRA, MergedColumnParallelLinearWithShardedLoRA,
...@@ -35,7 +36,9 @@ from vllm.lora.layers import ( ...@@ -35,7 +36,9 @@ from vllm.lora.layers import (
RowParallelLinearWithShardedLoRA, RowParallelLinearWithShardedLoRA,
VocabParallelEmbeddingWithLoRA, VocabParallelEmbeddingWithLoRA,
) )
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.utils import get_moe_expert_mapping, get_packed_modules_mapping
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
...@@ -58,9 +61,18 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = { ...@@ -58,9 +61,18 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = {
MergedColumnParallelLinearWithShardedLoRA, MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLoRA, MergedQKVParallelLinearWithShardedLoRA,
RowParallelLinearWithShardedLoRA, RowParallelLinearWithShardedLoRA,
FusedMoEWithLoRA,
} }
def is_moe_model(model: nn.Module) -> bool:
    """Return True when ``model`` contains at least one ``FusedMoE`` submodule.

    The first time an MoE layer is found, a one-time info-level message is
    logged announcing that the fused MoE LoRA implementation is used.

    Args:
        model: The module tree to scan via ``model.modules()``.

    Returns:
        True if any submodule is a ``FusedMoE`` instance, otherwise False.
    """
    for submodule in model.modules():
        if isinstance(submodule, FusedMoE):
            # Scanning stops at the first hit; one MoE layer is enough.
            logger.info_once("MoE model detected. Using fused MoE LoRA implementation.")
            return True
    return False
def from_layer( def from_layer(
layer: nn.Module, layer: nn.Module,
max_loras: int, max_loras: int,
...@@ -205,6 +217,9 @@ def get_supported_lora_modules(model: nn.Module) -> list[str]: ...@@ -205,6 +217,9 @@ def get_supported_lora_modules(model: nn.Module) -> list[str]:
if isinstance(module, (LinearBase,)): if isinstance(module, (LinearBase,)):
supported_lora_modules.add(name.split(".")[-1]) supported_lora_modules.add(name.split(".")[-1])
if isinstance(module, (FusedMoE,)):
supported_lora_modules.add(name.split(".")[-1])
return list(supported_lora_modules) return list(supported_lora_modules)
...@@ -252,3 +267,27 @@ def get_adapter_absolute_path(lora_path: str) -> str: ...@@ -252,3 +267,27 @@ def get_adapter_absolute_path(lora_path: str) -> str:
return lora_path return lora_path
return local_snapshot_path return local_snapshot_path
def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]:
    """Return the packed-modules mapping for ``model``, extended for MoE.

    For a non-MoE model this is exactly ``get_packed_modules_mapping(model)``.
    For an MoE model the same mapping is returned with an extra ``"experts"``
    entry listing the weight names taken from the model's expert mapping
    (each with any trailing ``"."`` characters stripped).

    Args:
        model: The model whose packed-module layout is being described.

    Returns:
        A dict mapping packed module names to lists of submodule names.

    Raises:
        AttributeError: If the model is MoE but ``get_moe_expert_mapping``
            yields nothing (falsy result).
    """
    if not is_moe_model(model):
        return get_packed_modules_mapping(model)

    expert_entries = get_moe_expert_mapping(model)
    if not expert_entries:
        raise AttributeError(
            "To support LoRA for MoE model, "
            "'get_expert_mapping' must be implemented"
        )

    # Start from the model's regular mapping, then add the expert weight
    # names; each expert-mapping entry is a 4-tuple whose second element is
    # the weight name.
    mapping = get_packed_modules_mapping(model)
    expert_weight_names = []
    for _, weight_name, _, _ in expert_entries:
        expert_weight_names.append(weight_name.rstrip("."))
    mapping["experts"] = expert_weight_names
    return mapping
...@@ -94,7 +94,8 @@ class WorkerLoRAManager: ...@@ -94,7 +94,8 @@ class WorkerLoRAManager:
expected_lora_modules.extend(packed_modules_mapping[module]) expected_lora_modules.extend(packed_modules_mapping[module])
else: else:
expected_lora_modules.append(module) expected_lora_modules.append(module)
if module == "experts":
expected_lora_modules.append(module)
expected_lora_modules = list(set(expected_lora_modules)) expected_lora_modules = list(set(expected_lora_modules))
lora_path = get_adapter_absolute_path(lora_request.lora_path) lora_path = get_adapter_absolute_path(lora_request.lora_path)
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused MoE utilities for GPTQ.""" """Fused MoE utilities for GPTQ."""
from collections.abc import Callable
import torch import torch
import vllm._custom_ops as ops import vllm._custom_ops as ops
...@@ -11,6 +13,9 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( ...@@ -11,6 +13,9 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
batched_moe_align_block_size, batched_moe_align_block_size,
moe_align_block_size, moe_align_block_size,
) )
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate, TopKWeightAndReduceDelegate,
TopKWeightAndReduceNoOP, TopKWeightAndReduceNoOP,
...@@ -24,6 +29,21 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( ...@@ -24,6 +29,21 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
from vllm.scalar_type import ScalarType, scalar_types from vllm.scalar_type import ScalarType, scalar_types
def default_activation_func(
    activation: str, output: torch.Tensor, input: torch.Tensor
) -> None:
    """Dispatch to the fused "activation and mul" op named by ``activation``.

    Args:
        activation: Either ``"silu"`` or ``"swigluoai"``.
        output: Tensor passed as the first argument of the custom op
            (presumably the destination buffer — confirm against the op).
        input: Tensor passed as the second argument of the custom op.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """
    # Reject unknown names up front so the custom-op namespace is only
    # touched for supported activations.
    if activation != "silu" and activation != "swigluoai":
        raise ValueError(
            f"Unsupported activation: {activation}. "
            "Only silu and swigluoai activations are supported."
        )

    if activation == "silu":
        torch.ops._C.silu_and_mul(output, input)
        return

    # alpha = 1.702, limit = 7.0
    torch.ops._C.swigluoai_and_mul(output, input)
def _fused_marlin_moe( def _fused_marlin_moe(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
...@@ -36,12 +56,15 @@ def _fused_marlin_moe( ...@@ -36,12 +56,15 @@ def _fused_marlin_moe(
num_topk: int, num_topk: int,
quant_type: ScalarType, quant_type: ScalarType,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
activation: str,
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
block_size_m: int, block_size_m: int,
sorted_token_ids: torch.Tensor, sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor, expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor, num_tokens_post_padded: torch.Tensor,
activation: str = "silu",
activation_func: Callable[
[str, torch.Tensor, torch.Tensor], None
] = default_activation_func,
global_scale1: torch.Tensor | None = None, global_scale1: torch.Tensor | None = None,
global_scale2: torch.Tensor | None = None, global_scale2: torch.Tensor | None = None,
g_idx1: torch.Tensor | None = None, g_idx1: torch.Tensor | None = None,
...@@ -118,19 +141,8 @@ def _fused_marlin_moe( ...@@ -118,19 +141,8 @@ def _fused_marlin_moe(
is_zp_float=False, is_zp_float=False,
) )
if activation == "silu": activation_func(
torch.ops._C.silu_and_mul( activation, intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
)
elif activation == "swigluoai":
# alpha = 1.702, limit = 7.0
torch.ops._C.swigluoai_and_mul(
intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
)
else:
raise ValueError(
f"Unsupported activation: {activation}. "
"Only silu and swigluoai activations are supported."
) )
if output is None: if output is None:
...@@ -185,7 +197,11 @@ def fused_marlin_moe( ...@@ -185,7 +197,11 @@ def fused_marlin_moe(
quant_type_id: int, quant_type_id: int,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
activation: str | None = "silu", activation: str = "silu",
activation_func: Callable[
[str, torch.Tensor, torch.Tensor], None
] = default_activation_func,
moe_sum: Callable[[torch.Tensor, torch.Tensor], None] | None = None,
expert_map: torch.Tensor | None = None, expert_map: torch.Tensor | None = None,
global_scale1: torch.Tensor | None = None, global_scale1: torch.Tensor | None = None,
global_scale2: torch.Tensor | None = None, global_scale2: torch.Tensor | None = None,
...@@ -290,12 +306,13 @@ def fused_marlin_moe( ...@@ -290,12 +306,13 @@ def fused_marlin_moe(
num_topk=topk, num_topk=topk,
quant_type=quant_type, quant_type=quant_type,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
activation=activation,
expert_map=expert_map, expert_map=expert_map,
block_size_m=block_size_m, block_size_m=block_size_m,
sorted_token_ids=sorted_token_ids, sorted_token_ids=sorted_token_ids,
expert_ids=expert_ids, expert_ids=expert_ids,
num_tokens_post_padded=num_tokens_post_padded, num_tokens_post_padded=num_tokens_post_padded,
activation=activation,
activation_func=activation_func,
global_scale1=global_scale1, global_scale1=global_scale1,
global_scale2=global_scale2, global_scale2=global_scale2,
g_idx1=g_idx1, g_idx1=g_idx1,
...@@ -317,7 +334,10 @@ def fused_marlin_moe( ...@@ -317,7 +334,10 @@ def fused_marlin_moe(
else: else:
output = torch.empty_like(hidden_states) output = torch.empty_like(hidden_states)
if moe_sum is None:
return torch.sum(moe_output.view(-1, topk, K), dim=1, out=output) return torch.sum(moe_output.view(-1, topk, K), dim=1, out=output)
else:
return moe_sum(moe_output, output)
def batched_fused_marlin_moe( def batched_fused_marlin_moe(
...@@ -600,6 +620,8 @@ class MarlinExperts(MarlinExpertsBase): ...@@ -600,6 +620,8 @@ class MarlinExperts(MarlinExpertsBase):
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
activation=activation, activation=activation,
activation_func=self.activation,
moe_sum=self.moe_sum,
expert_map=expert_map, expert_map=expert_map,
output=output, output=output,
# Workspaces are swapped in workspace_shapes() to account for proper # Workspaces are swapped in workspace_shapes() to account for proper
...@@ -608,6 +630,19 @@ class MarlinExperts(MarlinExpertsBase): ...@@ -608,6 +630,19 @@ class MarlinExperts(MarlinExpertsBase):
intermediate_cache2=workspace13, intermediate_cache2=workspace13,
) )
def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
    """Forward the MoE reduction to ``ops.moe_sum``.

    Exposed as a method so it can be handed to ``fused_marlin_moe`` as its
    ``moe_sum`` callback.

    Args:
        input: Tensor passed as the first argument of ``ops.moe_sum``.
        output: Tensor passed as the second argument of ``ops.moe_sum``.
    """
    ops.moe_sum(input, output)
def modular_marlin_fused_moe(
    quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None
) -> mk.FusedMoEModularKernel:
    """Assemble a modular fused-MoE kernel backed by ``MarlinExperts``.

    Args:
        quant_config: Quantization config handed to ``MarlinExperts``.
        shared_experts: Optional module forwarded as the kernel's third
            argument; ``None`` by default.

    Returns:
        A ``FusedMoEModularKernel`` built from ``MoEPrepareAndFinalizeNoEP``,
        ``MarlinExperts(quant_config)`` and ``shared_experts``.
    """
    prepare_finalize = MoEPrepareAndFinalizeNoEP()
    marlin_experts = MarlinExperts(quant_config)
    return mk.FusedMoEModularKernel(prepare_finalize, marlin_experts, shared_experts)
class BatchedMarlinExperts(MarlinExpertsBase): class BatchedMarlinExperts(MarlinExpertsBase):
def __init__( def __init__(
......
...@@ -2135,13 +2135,18 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -2135,13 +2135,18 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
B_bias=self.w2_bias, B_bias=self.w2_bias,
) )
ops.moe_sum(intermediate_cache3, output) # separate function is required for MoE + LoRA
self.moe_sum(intermediate_cache3, output)
def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
    """Forward the MoE reduction to ``ops.moe_sum``.

    Kept as a separate method because MoE + LoRA requires the reduction
    step to be a distinct function.

    Args:
        input: Tensor passed as the first argument of ``ops.moe_sum``.
        output: Tensor passed as the second argument of ``ops.moe_sum``.
    """
    ops.moe_sum(input, output)
def modular_triton_fused_moe( def modular_triton_fused_moe(
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None
) -> mk.FusedMoEModularKernel: ) -> mk.FusedMoEModularKernel:
return mk.FusedMoEModularKernel( return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(), MoEPrepareAndFinalizeNoEP(),
TritonExperts(quant_config), TritonExperts(quant_config),
shared_experts,
) )
...@@ -557,6 +557,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -557,6 +557,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
torch.ops._C.silu_and_mul(output, input) torch.ops._C.silu_and_mul(output, input)
elif activation == "gelu": elif activation == "gelu":
torch.ops._C.gelu_and_mul(output, input) torch.ops._C.gelu_and_mul(output, input)
elif activation == "swigluoai":
# alpha = 1.702, limit = 7.0
torch.ops._C.swigluoai_and_mul(output, input)
else: else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}") raise ValueError(f"Unsupported FusedMoe activation: {activation}")
......
...@@ -1313,6 +1313,17 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR ...@@ -1313,6 +1313,17 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR
logits = self.logits_processor(self.lm_head, hidden_states) logits = self.logits_processor(self.lm_head, hidden_states)
return logits return logits
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
    """Return the expert parameter mapping for this model.

    Delegates to ``SharedFusedMoE.make_expert_params_mapping`` using the
    ``gate_proj`` / ``down_proj`` / ``up_proj`` checkpoint names, the
    configured number of routed experts, and no redundant experts.

    Returns:
        A list of ``(param_name, weight_name, expert_id, shard_id)`` tuples.
    """
    build_mapping = SharedFusedMoE.make_expert_params_mapping
    # Params for weights, fp8 weight scales, fp8 activation scales
    return build_mapping(
        ckpt_gate_proj_name="gate_proj",
        ckpt_down_proj_name="down_proj",
        ckpt_up_proj_name="up_proj",
        num_experts=self.config.n_routed_experts,
        num_redundant_experts=0,
    )
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
......
...@@ -32,7 +32,7 @@ from vllm.model_executor.models.utils import sequence_parallel_chunk ...@@ -32,7 +32,7 @@ from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv from vllm.utils import cdiv
from .interfaces import SupportsEagle3, SupportsPP from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
WeightsMapper, WeightsMapper,
...@@ -627,7 +627,7 @@ class GptOssModel(nn.Module): ...@@ -627,7 +627,7 @@ class GptOssModel(nn.Module):
) )
class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3): class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]} packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
...@@ -696,6 +696,17 @@ class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3): ...@@ -696,6 +696,17 @@ class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3):
logits = self.logits_processor(self.lm_head, hidden_states) logits = self.logits_processor(self.lm_head, hidden_states)
return logits return logits
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
    """Return the expert parameter mapping for this model.

    Delegates to ``FusedMoE.make_expert_params_mapping`` using the
    ``gate_proj`` / ``down_proj`` / ``up_proj`` checkpoint names, the
    configured number of local experts, and no redundant experts.

    Returns:
        A list of ``(param_name, weight_name, expert_id, shard_id)`` tuples.
    """
    build_mapping = FusedMoE.make_expert_params_mapping
    # Params for weights, weight scales, activation scales
    return build_mapping(
        ckpt_gate_proj_name="gate_proj",
        ckpt_down_proj_name="down_proj",
        ckpt_up_proj_name="up_proj",
        num_experts=self.config.num_local_experts,
        num_redundant_experts=0,
    )
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
......
...@@ -49,7 +49,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -49,7 +49,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
is_pp_missing_parameter, is_pp_missing_parameter,
...@@ -349,8 +349,6 @@ class OlmoeModel(nn.Module): ...@@ -349,8 +349,6 @@ class OlmoeModel(nn.Module):
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"), ("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
...@@ -433,17 +431,13 @@ class OlmoeModel(nn.Module): ...@@ -433,17 +431,13 @@ class OlmoeModel(nn.Module):
return loaded_params return loaded_params
class OlmoeForCausalLM(nn.Module, SupportsPP): class OlmoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
"k_proj", "k_proj",
"v_proj", "v_proj",
], ]
"gate_up_proj": [
"gate_proj",
"up_proj",
],
} }
def __init__( def __init__(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment