Commit 1016e339 authored by 王敏's avatar 王敏
Browse files

[feat]量化模型添加支持 moe_fused_gate kernel,并使用VLLM_ENABLE_MOE_FUSED_GATE环境变量控制开关,默认打开

parent df877aad
......@@ -127,6 +127,7 @@ if TYPE_CHECKING:
VLLM_FLASH_ATTN_BACKEND: bool = False
VLLM_ENABLE_TBO: bool = False
VLLM_ZERO_OVERHEAD: bool = False
VLLM_ENABLE_MOE_FUSED_GATE: bool = False
def get_default_cache_root():
return os.getenv(
......@@ -819,6 +820,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Enable zero overhead scheduler.
"VLLM_ZERO_OVERHEAD":
lambda: bool(int(os.getenv("VLLM_ZERO_OVERHEAD", "0"))),
# If set, vLLM will enable the moe_fused_gate kernel.
"VLLM_ENABLE_MOE_FUSED_GATE":
lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_FUSED_GATE", "1"))),
}
# end-env-vars-definition
......
......@@ -19,6 +19,8 @@ from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, grouped_topk, is_power_of_two)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
......@@ -175,6 +177,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input: bool = False,
activation: str = "silu",
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
) -> torch.Tensor:
return self.forward(
x=x,
......@@ -192,7 +196,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
e_score_correction_bias=e_score_correction_bias,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_nn_moe=use_nn_moe)
use_nn_moe=use_nn_moe,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate)
def forward_cuda(
self,
......@@ -212,6 +218,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input: bool = False,
activation: str = "silu",
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
) -> torch.Tensor:
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
......@@ -224,7 +232,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor if hasattr(self, "routed_scaling_factor") else None)
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate)
return fused_experts(
hidden_states=x,
......@@ -257,6 +266,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
**kwargs,
):
assert activation == "silu", f"{activation} is not supported."
......@@ -292,6 +303,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input: bool = False,
activation: str = "silu",
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
) -> torch.Tensor:
assert not use_grouped_topk
assert num_expert_group is None
......@@ -558,10 +571,16 @@ class FusedMoE(torch.nn.Module):
self.quant_method.create_weights(layer=self, **moe_quant_params)
setattr(self.quant_method, "routed_scaling_factor", self.routed_scaling_factor)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce
# moe_fused_gate kernel ensure that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion.
self.use_fused_gate = envs.VLLM_ENABLE_MOE_FUSED_GATE \
and self.e_score_correction_bias is not None \
and self.global_num_experts // num_expert_group <= 32 \
and is_power_of_two(e_score_correction_bias.shape[0])
def _load_per_tensor_weight_scale(self, shard_id: str,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
......@@ -845,19 +864,13 @@ class FusedMoE(torch.nn.Module):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None,):
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, grouped_topk, is_power_of_two)
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False):
# DeekSeekv2 uses grouped_top_k
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
if e_score_correction_bias is not None \
and router_logits.shape[1] // num_expert_group <= 32 \
and is_power_of_two(e_score_correction_bias.shape[0]):
# moe_fused_gate kernel ensure that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion.
if use_fused_gate:
topk_weights, topk_ids = ops.moe_fused_gate(
router_logits,
e_score_correction_bias,
......@@ -947,7 +960,9 @@ class FusedMoE(torch.nn.Module):
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
use_nn_moe=self.use_nn_moe
use_nn_moe=self.use_nn_moe,
routed_scaling_factor=self.routed_scaling_factor,
use_fused_gate=self.use_fused_gate
)
if self.dp_size > 1:
......
......@@ -384,6 +384,8 @@ class BlockInt8MoEMethod:
apply_router_weight_on_input: bool = False,
activation: str = "silu",
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
......@@ -399,7 +401,9 @@ class BlockInt8MoEMethod:
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias
e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate
)
# Expert fusion with INT8 quantization
......
......@@ -296,6 +296,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
apply_router_weight_on_input: bool = False,
activation: str = "silu",
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
assert activation == "silu", "Only SiLU activation is supported."
......@@ -309,7 +311,9 @@ class MoeWNA16Method(FusedMoEMethodBase):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate)
weight_bits = self.quant_config.weight_bits
has_zp = self.quant_config.has_zp
......
......@@ -252,6 +252,8 @@ class W8A8Int8MoEMethod:
apply_router_weight_on_input: bool = False,
activation: str = "silu",
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
......@@ -266,7 +268,9 @@ class W8A8Int8MoEMethod:
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias
e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate
)
return fused_experts(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment