Commit 4f11b099 authored by wujl5's avatar wujl5
Browse files

fix: 修复融合moe.quant对其他模型的多余传参影响

parent 48742057
......@@ -172,8 +172,8 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_MOE_SUM: bool = False
VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = True
USE_FUSED_SILU_MUL_QUANT: bool = True
USE_FUSED_RMS_QUANT: bool = False
USE_FUSED_SILU_MUL_QUANT: bool = False
VLLM_P2P_ASYNC: bool = False
VLLM_P2P_BUF_TOKENS: int = 30000
VLLM_SCHED_ENABLE_MINIMAL_INJECTION: bool = False
......@@ -1158,12 +1158,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If set to a non-zero integer, vllm will use the rmsquant fused op
# (disabled by default)
"USE_FUSED_RMS_QUANT":
lambda: bool(int(os.getenv("USE_FUSED_RMS_QUANT", "1"))),
lambda: bool(int(os.getenv("USE_FUSED_RMS_QUANT", "0"))),
# If set to a non-zero integer, vllm will use the silu_mul_quant fused op.
# This variable has a default value of false,
# and when enabled it is still controlled by CRQ and RQ.
"USE_FUSED_SILU_MUL_QUANT":
lambda: bool(int(os.getenv("USE_FUSED_SILU_MUL_QUANT", "1"))),
lambda: bool(int(os.getenv("USE_FUSED_SILU_MUL_QUANT", "0"))),
# vllm pd separation will be used async
"VLLM_P2P_ASYNC":
......
......@@ -377,7 +377,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
shared_output: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = 1.0,
use_fused_gate: Optional[bool] = False,
use_fused_gate: Optional[bool] = False, **_
) -> torch.Tensor:
if enable_eplb:
raise NotImplementedError(
......@@ -1542,34 +1542,61 @@ class FusedMoE(torch.nn.Module):
hidden_states, router_logits = get_ep_group().dispatch(
hidden_states, router_logits)
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
enable_eplb=self.enable_eplb,
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
shared_output=shared_output,
use_nn_moe=self.use_nn_moe,
routed_scaling_factor=self.routed_scaling_factor,
use_fused_gate=self.use_fused_gate,
i_q=i_q,
i_s=i_s
)
if envs.USE_FUSED_RMS_QUANT:
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
enable_eplb=self.enable_eplb,
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
shared_output=shared_output,
use_nn_moe=self.use_nn_moe,
routed_scaling_factor=self.routed_scaling_factor,
use_fused_gate=self.use_fused_gate,
i_q=i_q,
i_s=i_s
)
else:
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
enable_eplb=self.enable_eplb,
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
shared_output=shared_output,
use_nn_moe=self.use_nn_moe,
routed_scaling_factor=self.routed_scaling_factor,
use_fused_gate=self.use_fused_gate,
)
if do_naive_dispatch_combine:
final_hidden_states = get_ep_group().combine(final_hidden_states)
......@@ -1645,8 +1672,11 @@ def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
assert self.quant_method is not None
return self.forward_impl(hidden_states, router_logits, shared_output, i_q, i_s)
if envs.USE_FUSED_RMS_QUANT:
return self.forward_impl(hidden_states, router_logits, shared_output, i_q, i_s)
else:
return self.forward_impl(hidden_states, router_logits, shared_output)
def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment