Commit 989a0a2b authored by wanglong3's avatar wanglong3 Committed by zhuwenwen
Browse files

feat: Support enable rms quant and shared expert overlap at same time.

parent cc946d6e
......@@ -1559,12 +1559,22 @@ class FusedMoE(torch.nn.Module):
return self.forward_impl(hidden_states, router_logits)
else:
if self.shared_experts is None:
return torch.ops.vllm.moe_forward(hidden_states, router_logits,
self.layer_name, shared_output,
i_q, i_s)
return torch.ops.vllm.moe_forward(
hidden_states = hidden_states,
router_logits = router_logits,
layer_name = self.layer_name,
shared_output = shared_output,
i_q = i_q,
i_s = i_s)
else:
return torch.ops.vllm.moe_forward_shared(hidden_states, router_logits,
self.layer_name, hidden_states_copy, shared_output)
return torch.ops.vllm.moe_forward_shared(
hidden_states = hidden_states,
router_logits = router_logits,
layer_name = self.layer_name,
hidden_states_copy = hidden_states_copy,
shared_output = shared_output,
i_q = i_q,
i_s = i_s)
def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor):
......@@ -1730,24 +1740,27 @@ class FusedMoE(torch.nn.Module):
use_fused_gate=self.use_fused_gate,
)
if enable_shared_experts_overlap:
assert self.shared_experts is not None
# Run shared experts in parallel on a separate stream
# NOTE: We start the separate stream here and mark the
# sync end point immediately after it is done. This is
# important to avoid excessive stream allocations by the cuda
# graph replay later.
with torch.cuda.stream(self.shared_experts_stream):
# Note that hidden_states clone() is necessary here to avoid
# conflict with the main stream
assert hidden_states_copy is not None
if enable_shared_experts_overlap:
assert self.shared_experts is not None
# Run shared experts in parallel on a separate stream
# NOTE: We start the separate stream here and mark the
# sync end point immediately after it is done. This is
# important to avoid excessive stream allocations by the cuda
# graph replay later.
with torch.cuda.stream(self.shared_experts_stream):
# Note that hidden_states clone() is necessary here to avoid
# conflict with the main stream
assert hidden_states_copy is not None
if envs.USE_FUSED_RMS_QUANT:
shared_output = self.shared_experts(hidden_states_copy, iqis=(i_q, i_s))
else:
shared_output = self.shared_experts(hidden_states_copy)
torch.cuda.current_stream().wait_stream(self.shared_experts_stream)
torch.cuda.current_stream().wait_stream(self.shared_experts_stream)
final_hidden_states = (
shared_output,
final_hidden_states,
)
final_hidden_states = (
shared_output,
final_hidden_states,
)
def combine_output(states: torch.Tensor) -> torch.Tensor:
if do_naive_dispatch_combine:
......@@ -1762,7 +1775,7 @@ class FusedMoE(torch.nn.Module):
states)
return states
if enable_shared_experts_overlap and not envs.USE_FUSED_RMS_QUANT:
if enable_shared_experts_overlap:
return (
final_hidden_states[0],
combine_output(final_hidden_states[1]),
......@@ -1855,24 +1868,31 @@ direct_register_custom_op(
)
def moe_forward_shared(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
hidden_states_copy: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
hidden_states_copy: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
assert self.shared_experts is not None
return self.forward_impl(hidden_states, router_logits, hidden_states_copy, shared_output)
if envs.USE_FUSED_RMS_QUANT:
return self.forward_impl(hidden_states, router_logits, hidden_states_copy = hidden_states_copy, i_q = i_q, i_s = i_s)
else:
return self.forward_impl(hidden_states, router_logits, hidden_states_copy = hidden_states_copy)
def moe_forward_shared_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
hidden_states_copy: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
hidden_states_copy: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
shared_out = torch.empty_like(hidden_states)
fused_out = torch.empty_like(hidden_states)
......
......@@ -34,7 +34,10 @@ class SharedFusedMoE(FusedMoE):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
hidden_states_copy: Optional[torch.Tensor] = None, **_
hidden_states_copy: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None,
**_
) -> tuple[torch.Tensor, torch.Tensor]|torch.Tensor:
if not self.use_overlapped:
shared_out = self._shared_experts(hidden_states)
......@@ -55,5 +58,7 @@ class SharedFusedMoE(FusedMoE):
hidden_states=hidden_states,
router_logits=router_logits,
hidden_states_copy = hidden_states_copy,
i_s = i_s,
i_q = i_q,
)
return fused_out
......@@ -200,7 +200,6 @@ class DeepseekV2MoE(nn.Module):
prefix=f"{prefix}.shared_experts",
)
self.enable_shared_experts_overlap = (not envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM
and not envs.USE_FUSED_RMS_QUANT
and not envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD
and config.n_shared_experts is not None)
......@@ -280,19 +279,27 @@ class DeepseekV2MoE(nn.Module):
def forward(self, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor] = None,
xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
iqis: Optional[tuple[torch.Tensor, torch.Tensor]] = None
iqis: Optional[tuple[torch.Tensor, torch.Tensor]] = None # iq = input quant, is = input scale
) -> Union[torch.Tensor,
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
def shared_exprts_overlap_pass(
hidden_states: torch.Tensor, router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
hidden_states: torch.Tensor, router_logits: torch.Tensor,
iqis: Optional[tuple[torch.Tensor, torch.Tensor]] = None
) -> tuple[torch.Tensor, torch.Tensor]:
i_q, i_s = None, None
if envs.USE_FUSED_RMS_QUANT:
assert iqis is not None
i_q, i_s = iqis[0], iqis[1]
hidden_states_copy = hidden_states.clone()
return self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
hidden_states_copy = hidden_states_copy)
hidden_states = hidden_states,
router_logits = router_logits,
hidden_states_copy = hidden_states_copy,
i_q = i_q,
i_s = i_s)
if envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
......@@ -339,7 +346,8 @@ class DeepseekV2MoE(nn.Module):
* (1. / self.routed_scaling_factor)
else: # RQ
if not self.enable_expert_parallel:
i_q, i_s = None, None
assert iqis is not None
i_q, i_s = iqis
if self.run_shared_expert_singlely:
if envs.USE_FUSED_RMS_QUANT:
shared_output = self.shared_experts(hidden_states, iqis=iqis)
......@@ -350,7 +358,7 @@ class DeepseekV2MoE(nn.Module):
if self.enable_shared_experts_overlap:
assert self.shared_experts is not None
shared_output, final_hidden_states = shared_exprts_overlap_pass(hidden_states, router_logits)
shared_output, final_hidden_states = shared_exprts_overlap_pass(hidden_states, router_logits, iqis = iqis)
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
if hidden_states.dtype != torch.float16:
......@@ -753,9 +761,7 @@ class DeepseekV2MLAAttention(nn.Module):
self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2])
# TODO wjl: 这里的forward拆了
def forward(
self,
positions: torch.Tensor,
......@@ -772,9 +778,6 @@ class DeepseekV2MLAAttention(nn.Module):
if envs.USE_FUSED_RMS_QUANT and iqis is not None:
if envs.VLLM_USE_FUSED_QA_KVA_GEMM:
if self.q_lora_rank is not None:
# rms_weight=rms_weight, residual=residual, update_hd=False
qc_kvc_kpe, _bias = self.qa_kva_proj(hidden_states, iqis)
q_c = qc_kvc_kpe[:, :self.q_lora_rank]
kvc_kpe = qc_kvc_kpe[:, self.q_lora_rank:]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment