Commit 1016e339 authored by 王敏's avatar 王敏
Browse files

[feat]量化模型添加支持 moe_fused_gate kernel,并使用VLLM_ENABLE_MOE_FUSED_GATE环境变量控制开关,默认打开

parent df877aad
...@@ -127,6 +127,7 @@ if TYPE_CHECKING: ...@@ -127,6 +127,7 @@ if TYPE_CHECKING:
VLLM_FLASH_ATTN_BACKEND: bool = False VLLM_FLASH_ATTN_BACKEND: bool = False
VLLM_ENABLE_TBO: bool = False VLLM_ENABLE_TBO: bool = False
VLLM_ZERO_OVERHEAD: bool = False VLLM_ZERO_OVERHEAD: bool = False
VLLM_ENABLE_MOE_FUSED_GATE: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -798,7 +799,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -798,7 +799,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If set, vLLM will disable the draft model in cudagraph mode. # If set, vLLM will disable the draft model in cudagraph mode.
"VLLM_ENFORCE_EAGER_BS_THRESHOLD": "VLLM_ENFORCE_EAGER_BS_THRESHOLD":
lambda: int(os.environ.get("VLLM_ENFORCE_EAGER_BS_THRESHOLD", "-1")), lambda: int(os.environ.get("VLLM_ENFORCE_EAGER_BS_THRESHOLD", "-1")),
# 'has_comtext' is a variable in common.py, which is calculated # 'has_comtext' is a variable in common.py, which is calculated
# by metadata by default. However, it may introduce synchronization # by metadata by default. However, it may introduce synchronization
# and affect performance, so it is directly assigned as False. # and affect performance, so it is directly assigned as False.
...@@ -806,12 +807,12 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -806,12 +807,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
# to restore the default usage. # to restore the default usage.
"VLLM_HAS_CONTEXT_DEFAULT": "VLLM_HAS_CONTEXT_DEFAULT":
lambda: bool(int(os.getenv("VLLM_HAS_CONTEXT_DEFAULT", "0"))), lambda: bool(int(os.getenv("VLLM_HAS_CONTEXT_DEFAULT", "0"))),
# If set, vLLM will use FlashAttention Backend for attention computation on rocm # If set, vLLM will use FlashAttention Backend for attention computation on rocm
"VLLM_FLASH_ATTN_BACKEND": "VLLM_FLASH_ATTN_BACKEND":
lambda: (os.environ.get("VLLM_FLASH_ATTN_BACKEND", "False").lower() in lambda: (os.environ.get("VLLM_FLASH_ATTN_BACKEND", "False").lower() in
("true", "1")), ("true", "1")),
# Enable two batch overlap. # Enable two batch overlap.
"VLLM_ENABLE_TBO": "VLLM_ENABLE_TBO":
lambda: bool(int(os.getenv("VLLM_ENABLE_TBO", "0"))), lambda: bool(int(os.getenv("VLLM_ENABLE_TBO", "0"))),
...@@ -819,6 +820,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -819,6 +820,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Enable zero overhead scheduler. # Enable zero overhead scheduler.
"VLLM_ZERO_OVERHEAD": "VLLM_ZERO_OVERHEAD":
lambda: bool(int(os.getenv("VLLM_ZERO_OVERHEAD", "0"))), lambda: bool(int(os.getenv("VLLM_ZERO_OVERHEAD", "0"))),
# If set, vLLM will enable the moe_fused_gate kernel.
"VLLM_ENABLE_MOE_FUSED_GATE":
lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_FUSED_GATE", "1"))),
} }
# end-env-vars-definition # end-env-vars-definition
......
...@@ -19,6 +19,8 @@ from vllm.logger import init_logger ...@@ -19,6 +19,8 @@ from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, grouped_topk, is_power_of_two)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum from vllm.platforms.interface import CpuArchEnum
...@@ -175,6 +177,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -175,6 +177,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
return self.forward( return self.forward(
x=x, x=x,
...@@ -192,7 +196,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -192,7 +196,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate)
def forward_cuda( def forward_cuda(
self, self,
...@@ -212,6 +218,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -212,6 +218,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
...@@ -224,7 +232,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -224,7 +232,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor if hasattr(self, "routed_scaling_factor") else None) routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate)
return fused_experts( return fused_experts(
hidden_states=x, hidden_states=x,
...@@ -257,6 +266,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -257,6 +266,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
**kwargs, **kwargs,
): ):
assert activation == "silu", f"{activation} is not supported." assert activation == "silu", f"{activation} is not supported."
...@@ -292,6 +303,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -292,6 +303,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
assert not use_grouped_topk assert not use_grouped_topk
assert num_expert_group is None assert num_expert_group is None
...@@ -558,10 +571,16 @@ class FusedMoE(torch.nn.Module): ...@@ -558,10 +571,16 @@ class FusedMoE(torch.nn.Module):
self.quant_method.create_weights(layer=self, **moe_quant_params) self.quant_method.create_weights(layer=self, **moe_quant_params)
setattr(self.quant_method, "routed_scaling_factor", self.routed_scaling_factor)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce self.tbo_all_reduce = tbo_all_reduce
# moe_fused_gate kernel ensure that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion.
self.use_fused_gate = envs.VLLM_ENABLE_MOE_FUSED_GATE \
and self.e_score_correction_bias is not None \
and self.global_num_experts // num_expert_group <= 32 \
and is_power_of_two(e_score_correction_bias.shape[0])
def _load_per_tensor_weight_scale(self, shard_id: str, def _load_per_tensor_weight_scale(self, shard_id: str,
param: torch.nn.Parameter, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
...@@ -845,19 +864,13 @@ class FusedMoE(torch.nn.Module): ...@@ -845,19 +864,13 @@ class FusedMoE(torch.nn.Module):
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None,): routed_scaling_factor: Optional[float] = None,
from vllm.model_executor.layers.fused_moe.fused_moe import ( use_fused_gate: Optional[bool] = False):
fused_topk, grouped_topk, is_power_of_two)
# DeekSeekv2 uses grouped_top_k # DeekSeekv2 uses grouped_top_k
if use_grouped_topk: if use_grouped_topk:
assert topk_group is not None assert topk_group is not None
assert num_expert_group is not None assert num_expert_group is not None
if e_score_correction_bias is not None \ if use_fused_gate:
and router_logits.shape[1] // num_expert_group <= 32 \
and is_power_of_two(e_score_correction_bias.shape[0]):
# moe_fused_gate kernel ensure that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion.
topk_weights, topk_ids = ops.moe_fused_gate( topk_weights, topk_ids = ops.moe_fused_gate(
router_logits, router_logits,
e_score_correction_bias, e_score_correction_bias,
...@@ -947,7 +960,9 @@ class FusedMoE(torch.nn.Module): ...@@ -947,7 +960,9 @@ class FusedMoE(torch.nn.Module):
e_score_correction_bias=self.e_score_correction_bias, e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation, activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input, apply_router_weight_on_input=self.apply_router_weight_on_input,
use_nn_moe=self.use_nn_moe use_nn_moe=self.use_nn_moe,
routed_scaling_factor=self.routed_scaling_factor,
use_fused_gate=self.use_fused_gate
) )
if self.dp_size > 1: if self.dp_size > 1:
......
...@@ -384,6 +384,8 @@ class BlockInt8MoEMethod: ...@@ -384,6 +384,8 @@ class BlockInt8MoEMethod:
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -399,7 +401,9 @@ class BlockInt8MoEMethod: ...@@ -399,7 +401,9 @@ class BlockInt8MoEMethod:
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate
) )
# Expert fusion with INT8 quantization # Expert fusion with INT8 quantization
......
...@@ -296,6 +296,8 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -296,6 +296,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
assert activation == "silu", "Only SiLU activation is supported." assert activation == "silu", "Only SiLU activation is supported."
...@@ -309,7 +311,9 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -309,7 +311,9 @@ class MoeWNA16Method(FusedMoEMethodBase):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate)
weight_bits = self.quant_config.weight_bits weight_bits = self.quant_config.weight_bits
has_zp = self.quant_config.has_zp has_zp = self.quant_config.has_zp
......
...@@ -252,6 +252,8 @@ class W8A8Int8MoEMethod: ...@@ -252,6 +252,8 @@ class W8A8Int8MoEMethod:
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -266,7 +268,9 @@ class W8A8Int8MoEMethod: ...@@ -266,7 +268,9 @@ class W8A8Int8MoEMethod:
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate
) )
return fused_experts( return fused_experts(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment