Commit 989a0a2b authored by wanglong3, committed by zhuwenwen
Browse files

feat: Support enabling RMS quant and shared-expert overlap at the same time.

parent cc946d6e
...@@ -1559,12 +1559,22 @@ class FusedMoE(torch.nn.Module): ...@@ -1559,12 +1559,22 @@ class FusedMoE(torch.nn.Module):
return self.forward_impl(hidden_states, router_logits) return self.forward_impl(hidden_states, router_logits)
else: else:
if self.shared_experts is None: if self.shared_experts is None:
return torch.ops.vllm.moe_forward(hidden_states, router_logits, return torch.ops.vllm.moe_forward(
self.layer_name, shared_output, hidden_states = hidden_states,
i_q, i_s) router_logits = router_logits,
layer_name = self.layer_name,
shared_output = shared_output,
i_q = i_q,
i_s = i_s)
else: else:
return torch.ops.vllm.moe_forward_shared(hidden_states, router_logits, return torch.ops.vllm.moe_forward_shared(
self.layer_name, hidden_states_copy, shared_output) hidden_states = hidden_states,
router_logits = router_logits,
layer_name = self.layer_name,
hidden_states_copy = hidden_states_copy,
shared_output = shared_output,
i_q = i_q,
i_s = i_s)
def forward_impl_chunked(self, full_hidden_states: torch.Tensor, def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor): full_router_logits: torch.Tensor):
...@@ -1730,24 +1740,27 @@ class FusedMoE(torch.nn.Module): ...@@ -1730,24 +1740,27 @@ class FusedMoE(torch.nn.Module):
use_fused_gate=self.use_fused_gate, use_fused_gate=self.use_fused_gate,
) )
if enable_shared_experts_overlap: if enable_shared_experts_overlap:
assert self.shared_experts is not None assert self.shared_experts is not None
# Run shared experts in parallel on a separate stream # Run shared experts in parallel on a separate stream
# NOTE: We start the separate stream here and mark the # NOTE: We start the separate stream here and mark the
# sync end point immediately after it is done. This is # sync end point immediately after it is done. This is
# important to avoid excessive stream allocations by the cuda # important to avoid excessive stream allocations by the cuda
# graph replay later. # graph replay later.
with torch.cuda.stream(self.shared_experts_stream): with torch.cuda.stream(self.shared_experts_stream):
# Note that hidden_states clone() is necessary here to avoid # Note that hidden_states clone() is necessary here to avoid
# conflict with the main stream # conflict with the main stream
assert hidden_states_copy is not None assert hidden_states_copy is not None
if envs.USE_FUSED_RMS_QUANT:
shared_output = self.shared_experts(hidden_states_copy, iqis=(i_q, i_s))
else:
shared_output = self.shared_experts(hidden_states_copy) shared_output = self.shared_experts(hidden_states_copy)
torch.cuda.current_stream().wait_stream(self.shared_experts_stream) torch.cuda.current_stream().wait_stream(self.shared_experts_stream)
final_hidden_states = ( final_hidden_states = (
shared_output, shared_output,
final_hidden_states, final_hidden_states,
) )
def combine_output(states: torch.Tensor) -> torch.Tensor: def combine_output(states: torch.Tensor) -> torch.Tensor:
if do_naive_dispatch_combine: if do_naive_dispatch_combine:
...@@ -1762,7 +1775,7 @@ class FusedMoE(torch.nn.Module): ...@@ -1762,7 +1775,7 @@ class FusedMoE(torch.nn.Module):
states) states)
return states return states
if enable_shared_experts_overlap and not envs.USE_FUSED_RMS_QUANT: if enable_shared_experts_overlap:
return ( return (
final_hidden_states[0], final_hidden_states[0],
combine_output(final_hidden_states[1]), combine_output(final_hidden_states[1]),
...@@ -1855,24 +1868,31 @@ direct_register_custom_op( ...@@ -1855,24 +1868,31 @@ direct_register_custom_op(
) )
def moe_forward_shared( def moe_forward_shared(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, layer_name: str,
hidden_states_copy: Optional[torch.Tensor] = None, hidden_states_copy: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
assert self.shared_experts is not None assert self.shared_experts is not None
return self.forward_impl(hidden_states, router_logits, hidden_states_copy, shared_output) if envs.USE_FUSED_RMS_QUANT:
return self.forward_impl(hidden_states, router_logits, hidden_states_copy = hidden_states_copy, i_q = i_q, i_s = i_s)
else:
return self.forward_impl(hidden_states, router_logits, hidden_states_copy = hidden_states_copy)
def moe_forward_shared_fake( def moe_forward_shared_fake(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, layer_name: str,
hidden_states_copy: Optional[torch.Tensor] = None, hidden_states_copy: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
shared_out = torch.empty_like(hidden_states) shared_out = torch.empty_like(hidden_states)
fused_out = torch.empty_like(hidden_states) fused_out = torch.empty_like(hidden_states)
......
...@@ -34,7 +34,10 @@ class SharedFusedMoE(FusedMoE): ...@@ -34,7 +34,10 @@ class SharedFusedMoE(FusedMoE):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
hidden_states_copy: Optional[torch.Tensor] = None, **_ hidden_states_copy: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None,
**_
) -> tuple[torch.Tensor, torch.Tensor]|torch.Tensor: ) -> tuple[torch.Tensor, torch.Tensor]|torch.Tensor:
if not self.use_overlapped: if not self.use_overlapped:
shared_out = self._shared_experts(hidden_states) shared_out = self._shared_experts(hidden_states)
...@@ -55,5 +58,7 @@ class SharedFusedMoE(FusedMoE): ...@@ -55,5 +58,7 @@ class SharedFusedMoE(FusedMoE):
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
hidden_states_copy = hidden_states_copy, hidden_states_copy = hidden_states_copy,
i_s = i_s,
i_q = i_q,
) )
return fused_out return fused_out
...@@ -200,7 +200,6 @@ class DeepseekV2MoE(nn.Module): ...@@ -200,7 +200,6 @@ class DeepseekV2MoE(nn.Module):
prefix=f"{prefix}.shared_experts", prefix=f"{prefix}.shared_experts",
) )
self.enable_shared_experts_overlap = (not envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM self.enable_shared_experts_overlap = (not envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM
and not envs.USE_FUSED_RMS_QUANT
and not envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD and not envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD
and config.n_shared_experts is not None) and config.n_shared_experts is not None)
...@@ -280,19 +279,27 @@ class DeepseekV2MoE(nn.Module): ...@@ -280,19 +279,27 @@ class DeepseekV2MoE(nn.Module):
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None, xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
iqis: Optional[tuple[torch.Tensor, torch.Tensor]] = None iqis: Optional[tuple[torch.Tensor, torch.Tensor]] = None # iq = input quant, is = input scale
) -> Union[torch.Tensor, ) -> Union[torch.Tensor,
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
def shared_exprts_overlap_pass( def shared_exprts_overlap_pass(
hidden_states: torch.Tensor, router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: hidden_states: torch.Tensor, router_logits: torch.Tensor,
iqis: Optional[tuple[torch.Tensor, torch.Tensor]] = None
) -> tuple[torch.Tensor, torch.Tensor]:
i_q, i_s = None, None
if envs.USE_FUSED_RMS_QUANT:
assert iqis is not None
i_q, i_s = iqis[0], iqis[1]
hidden_states_copy = hidden_states.clone() hidden_states_copy = hidden_states.clone()
return self.experts( return self.experts(
hidden_states=hidden_states, hidden_states = hidden_states,
router_logits=router_logits, router_logits = router_logits,
hidden_states_copy = hidden_states_copy) hidden_states_copy = hidden_states_copy,
i_q = i_q,
i_s = i_s)
if envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None: if envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
...@@ -339,7 +346,8 @@ class DeepseekV2MoE(nn.Module): ...@@ -339,7 +346,8 @@ class DeepseekV2MoE(nn.Module):
* (1. / self.routed_scaling_factor) * (1. / self.routed_scaling_factor)
else: # RQ else: # RQ
if not self.enable_expert_parallel: if not self.enable_expert_parallel:
i_q, i_s = None, None assert iqis is not None
i_q, i_s = iqis
if self.run_shared_expert_singlely: if self.run_shared_expert_singlely:
if envs.USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT:
shared_output = self.shared_experts(hidden_states, iqis=iqis) shared_output = self.shared_experts(hidden_states, iqis=iqis)
...@@ -350,7 +358,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -350,7 +358,7 @@ class DeepseekV2MoE(nn.Module):
if self.enable_shared_experts_overlap: if self.enable_shared_experts_overlap:
assert self.shared_experts is not None assert self.shared_experts is not None
shared_output, final_hidden_states = shared_exprts_overlap_pass(hidden_states, router_logits) shared_output, final_hidden_states = shared_exprts_overlap_pass(hidden_states, router_logits, iqis = iqis)
# Fix FP16 overflow # Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details. # See DeepseekV2DecoderLayer for more details.
if hidden_states.dtype != torch.float16: if hidden_states.dtype != torch.float16:
...@@ -753,9 +761,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -753,9 +761,7 @@ class DeepseekV2MLAAttention(nn.Module):
self.prefix = prefix self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2]) self.debug_layer_idx = int(self.prefix.split(".")[-2])
# TODO wjl: the forward here has been split up
def forward( def forward(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
...@@ -772,9 +778,6 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -772,9 +778,6 @@ class DeepseekV2MLAAttention(nn.Module):
if envs.USE_FUSED_RMS_QUANT and iqis is not None: if envs.USE_FUSED_RMS_QUANT and iqis is not None:
if envs.VLLM_USE_FUSED_QA_KVA_GEMM: if envs.VLLM_USE_FUSED_QA_KVA_GEMM:
if self.q_lora_rank is not None: if self.q_lora_rank is not None:
# rms_weight=rms_weight, residual=residual, update_hd=False
qc_kvc_kpe, _bias = self.qa_kva_proj(hidden_states, iqis) qc_kvc_kpe, _bias = self.qa_kva_proj(hidden_states, iqis)
q_c = qc_kvc_kpe[:, :self.q_lora_rank] q_c = qc_kvc_kpe[:, :self.q_lora_rank]
kvc_kpe = qc_kvc_kpe[:, self.q_lora_rank:] kvc_kpe = qc_kvc_kpe[:, self.q_lora_rank:]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment