Commit a6bed85b authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-fix_moe_quant' into 'v0.9.2-dev'

fix: moe quant bug and attention bug

See merge request dcutoolkit/deeplearing/vllm!275
parents 2fc5b0bb 8cfbe041
...@@ -172,8 +172,8 @@ if TYPE_CHECKING: ...@@ -172,8 +172,8 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_MOE_SUM: bool = False VLLM_USE_LIGHTOP_MOE_SUM: bool = False
VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = True USE_FUSED_RMS_QUANT: bool = False
USE_FUSED_SILU_MUL_QUANT: bool = True USE_FUSED_SILU_MUL_QUANT: bool = False
VLLM_P2P_ASYNC: bool = False VLLM_P2P_ASYNC: bool = False
VLLM_P2P_BUF_TOKENS: int = 30000 VLLM_P2P_BUF_TOKENS: int = 30000
VLLM_SCHED_ENABLE_MINIMAL_INJECTION: bool = False VLLM_SCHED_ENABLE_MINIMAL_INJECTION: bool = False
...@@ -1158,12 +1158,12 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1158,12 +1158,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vllm will use rmsquant fused op # vllm will use rmsquant fused op
"USE_FUSED_RMS_QUANT": "USE_FUSED_RMS_QUANT":
lambda: bool(int(os.getenv("USE_FUSED_RMS_QUANT", "1"))), lambda: bool(int(os.getenv("USE_FUSED_RMS_QUANT", "0"))),
# vllm will use silu_mul_quant fused op, # vllm will use silu_mul_quant fused op,
# This variable has a default value of true, # This variable has a default value of true,
# but it is still controlled by CRQ and RQ. # but it is still controlled by CRQ and RQ.
"USE_FUSED_SILU_MUL_QUANT": "USE_FUSED_SILU_MUL_QUANT":
lambda: bool(int(os.getenv("USE_FUSED_SILU_MUL_QUANT", "1"))), lambda: bool(int(os.getenv("USE_FUSED_SILU_MUL_QUANT", "0"))),
# vllm pd separation will be used async # vllm pd separation will be used async
"VLLM_P2P_ASYNC": "VLLM_P2P_ASYNC":
......
...@@ -377,7 +377,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -377,7 +377,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = 1.0, routed_scaling_factor: Optional[float] = 1.0,
use_fused_gate: Optional[bool] = False, use_fused_gate: Optional[bool] = False, **_
) -> torch.Tensor: ) -> torch.Tensor:
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
...@@ -1542,34 +1542,61 @@ class FusedMoE(torch.nn.Module): ...@@ -1542,34 +1542,61 @@ class FusedMoE(torch.nn.Module):
hidden_states, router_logits = get_ep_group().dispatch( hidden_states, router_logits = get_ep_group().dispatch(
hidden_states, router_logits) hidden_states, router_logits)
# Matrix multiply. if envs.USE_FUSED_RMS_QUANT:
final_hidden_states = self.quant_method.apply( # Matrix multiply.
layer=self, final_hidden_states = self.quant_method.apply(
x=hidden_states, layer=self,
router_logits=router_logits, x=hidden_states,
top_k=self.top_k, router_logits=router_logits,
renormalize=self.renormalize, top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk, renormalize=self.renormalize,
global_num_experts=self.global_num_experts, use_grouped_topk=self.use_grouped_topk,
expert_map=self.expert_map, global_num_experts=self.global_num_experts,
topk_group=self.topk_group, expert_map=self.expert_map,
num_expert_group=self.num_expert_group, topk_group=self.topk_group,
custom_routing_function=self.custom_routing_function, num_expert_group=self.num_expert_group,
scoring_func=self.scoring_func, custom_routing_function=self.custom_routing_function,
e_score_correction_bias=self.e_score_correction_bias, scoring_func=self.scoring_func,
activation=self.activation, e_score_correction_bias=self.e_score_correction_bias,
apply_router_weight_on_input=self.apply_router_weight_on_input, activation=self.activation,
enable_eplb=self.enable_eplb, apply_router_weight_on_input=self.apply_router_weight_on_input,
expert_load_view=self.expert_load_view, enable_eplb=self.enable_eplb,
logical_to_physical_map=self.logical_to_physical_map, expert_load_view=self.expert_load_view,
logical_replica_count=self.logical_replica_count, logical_to_physical_map=self.logical_to_physical_map,
shared_output=shared_output, logical_replica_count=self.logical_replica_count,
use_nn_moe=self.use_nn_moe, shared_output=shared_output,
routed_scaling_factor=self.routed_scaling_factor, use_nn_moe=self.use_nn_moe,
use_fused_gate=self.use_fused_gate, routed_scaling_factor=self.routed_scaling_factor,
i_q=i_q, use_fused_gate=self.use_fused_gate,
i_s=i_s i_q=i_q,
) i_s=i_s
)
else:
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
enable_eplb=self.enable_eplb,
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
shared_output=shared_output,
use_nn_moe=self.use_nn_moe,
routed_scaling_factor=self.routed_scaling_factor,
use_fused_gate=self.use_fused_gate,
)
if do_naive_dispatch_combine: if do_naive_dispatch_combine:
final_hidden_states = get_ep_group().combine(final_hidden_states) final_hidden_states = get_ep_group().combine(final_hidden_states)
...@@ -1645,8 +1672,11 @@ def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, ...@@ -1645,8 +1672,11 @@ def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
assert self.quant_method is not None assert self.quant_method is not None
if envs.USE_FUSED_RMS_QUANT:
return self.forward_impl(hidden_states, router_logits, shared_output, i_q, i_s) return self.forward_impl(hidden_states, router_logits, shared_output, i_q, i_s)
else:
return self.forward_impl(hidden_states, router_logits, shared_output)
def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
......
...@@ -576,6 +576,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -576,6 +576,7 @@ class FlashAttentionImpl(AttentionImpl):
layer._k_scale, layer._v_scale layer._k_scale, layer._v_scale
) )
else: else:
from vllm.attention.utils.fa_utils import reshape_and_cache_cuda
reshape_and_cache_cuda( reshape_and_cache_cuda(
key, key,
value, value,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment