Commit 9c15f410 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev_fix' into 'v0.9.2-dev'

feat: Support enable rms quant and shared expert overlap at same time.

See merge request dcutoolkit/deeplearing/vllm!352
parents cc946d6e 989a0a2b
...@@ -1559,12 +1559,22 @@ class FusedMoE(torch.nn.Module): ...@@ -1559,12 +1559,22 @@ class FusedMoE(torch.nn.Module):
return self.forward_impl(hidden_states, router_logits) return self.forward_impl(hidden_states, router_logits)
else: else:
if self.shared_experts is None: if self.shared_experts is None:
return torch.ops.vllm.moe_forward(hidden_states, router_logits, return torch.ops.vllm.moe_forward(
self.layer_name, shared_output, hidden_states = hidden_states,
i_q, i_s) router_logits = router_logits,
layer_name = self.layer_name,
shared_output = shared_output,
i_q = i_q,
i_s = i_s)
else: else:
return torch.ops.vllm.moe_forward_shared(hidden_states, router_logits, return torch.ops.vllm.moe_forward_shared(
self.layer_name, hidden_states_copy, shared_output) hidden_states = hidden_states,
router_logits = router_logits,
layer_name = self.layer_name,
hidden_states_copy = hidden_states_copy,
shared_output = shared_output,
i_q = i_q,
i_s = i_s)
def forward_impl_chunked(self, full_hidden_states: torch.Tensor, def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor): full_router_logits: torch.Tensor):
...@@ -1741,6 +1751,9 @@ class FusedMoE(torch.nn.Module): ...@@ -1741,6 +1751,9 @@ class FusedMoE(torch.nn.Module):
# Note that hidden_states clone() is necessary here to avoid # Note that hidden_states clone() is necessary here to avoid
# conflict with the main stream # conflict with the main stream
assert hidden_states_copy is not None assert hidden_states_copy is not None
if envs.USE_FUSED_RMS_QUANT:
shared_output = self.shared_experts(hidden_states_copy, iqis=(i_q, i_s))
else:
shared_output = self.shared_experts(hidden_states_copy) shared_output = self.shared_experts(hidden_states_copy)
torch.cuda.current_stream().wait_stream(self.shared_experts_stream) torch.cuda.current_stream().wait_stream(self.shared_experts_stream)
...@@ -1762,7 +1775,7 @@ class FusedMoE(torch.nn.Module): ...@@ -1762,7 +1775,7 @@ class FusedMoE(torch.nn.Module):
states) states)
return states return states
if enable_shared_experts_overlap and not envs.USE_FUSED_RMS_QUANT: if enable_shared_experts_overlap:
return ( return (
final_hidden_states[0], final_hidden_states[0],
combine_output(final_hidden_states[1]), combine_output(final_hidden_states[1]),
...@@ -1859,12 +1872,17 @@ def moe_forward_shared( ...@@ -1859,12 +1872,17 @@ def moe_forward_shared(
router_logits: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, layer_name: str,
hidden_states_copy: Optional[torch.Tensor] = None, hidden_states_copy: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
assert self.shared_experts is not None assert self.shared_experts is not None
return self.forward_impl(hidden_states, router_logits, hidden_states_copy, shared_output) if envs.USE_FUSED_RMS_QUANT:
return self.forward_impl(hidden_states, router_logits, hidden_states_copy = hidden_states_copy, i_q = i_q, i_s = i_s)
else:
return self.forward_impl(hidden_states, router_logits, hidden_states_copy = hidden_states_copy)
def moe_forward_shared_fake( def moe_forward_shared_fake(
...@@ -1872,7 +1890,9 @@ def moe_forward_shared_fake( ...@@ -1872,7 +1890,9 @@ def moe_forward_shared_fake(
router_logits: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, layer_name: str,
hidden_states_copy: Optional[torch.Tensor] = None, hidden_states_copy: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
shared_out = torch.empty_like(hidden_states) shared_out = torch.empty_like(hidden_states)
fused_out = torch.empty_like(hidden_states) fused_out = torch.empty_like(hidden_states)
......
...@@ -34,7 +34,10 @@ class SharedFusedMoE(FusedMoE): ...@@ -34,7 +34,10 @@ class SharedFusedMoE(FusedMoE):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
hidden_states_copy: Optional[torch.Tensor] = None, **_ hidden_states_copy: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None,
**_
) -> tuple[torch.Tensor, torch.Tensor]|torch.Tensor: ) -> tuple[torch.Tensor, torch.Tensor]|torch.Tensor:
if not self.use_overlapped: if not self.use_overlapped:
shared_out = self._shared_experts(hidden_states) shared_out = self._shared_experts(hidden_states)
...@@ -55,5 +58,7 @@ class SharedFusedMoE(FusedMoE): ...@@ -55,5 +58,7 @@ class SharedFusedMoE(FusedMoE):
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
hidden_states_copy = hidden_states_copy, hidden_states_copy = hidden_states_copy,
i_s = i_s,
i_q = i_q,
) )
return fused_out return fused_out
...@@ -200,7 +200,6 @@ class DeepseekV2MoE(nn.Module): ...@@ -200,7 +200,6 @@ class DeepseekV2MoE(nn.Module):
prefix=f"{prefix}.shared_experts", prefix=f"{prefix}.shared_experts",
) )
self.enable_shared_experts_overlap = (not envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM self.enable_shared_experts_overlap = (not envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM
and not envs.USE_FUSED_RMS_QUANT
and not envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD and not envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD
and config.n_shared_experts is not None) and config.n_shared_experts is not None)
...@@ -280,19 +279,27 @@ class DeepseekV2MoE(nn.Module): ...@@ -280,19 +279,27 @@ class DeepseekV2MoE(nn.Module):
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None, xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
iqis: Optional[tuple[torch.Tensor, torch.Tensor]] = None iqis: Optional[tuple[torch.Tensor, torch.Tensor]] = None # iq = input quant, is = input scale
) -> Union[torch.Tensor, ) -> Union[torch.Tensor,
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
def shared_exprts_overlap_pass( def shared_exprts_overlap_pass(
hidden_states: torch.Tensor, router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: hidden_states: torch.Tensor, router_logits: torch.Tensor,
iqis: Optional[tuple[torch.Tensor, torch.Tensor]] = None
) -> tuple[torch.Tensor, torch.Tensor]:
i_q, i_s = None, None
if envs.USE_FUSED_RMS_QUANT:
assert iqis is not None
i_q, i_s = iqis[0], iqis[1]
hidden_states_copy = hidden_states.clone() hidden_states_copy = hidden_states.clone()
return self.experts( return self.experts(
hidden_states=hidden_states, hidden_states = hidden_states,
router_logits=router_logits, router_logits = router_logits,
hidden_states_copy = hidden_states_copy) hidden_states_copy = hidden_states_copy,
i_q = i_q,
i_s = i_s)
if envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None: if envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
...@@ -339,7 +346,8 @@ class DeepseekV2MoE(nn.Module): ...@@ -339,7 +346,8 @@ class DeepseekV2MoE(nn.Module):
* (1. / self.routed_scaling_factor) * (1. / self.routed_scaling_factor)
else: # RQ else: # RQ
if not self.enable_expert_parallel: if not self.enable_expert_parallel:
i_q, i_s = None, None assert iqis is not None
i_q, i_s = iqis
if self.run_shared_expert_singlely: if self.run_shared_expert_singlely:
if envs.USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT:
shared_output = self.shared_experts(hidden_states, iqis=iqis) shared_output = self.shared_experts(hidden_states, iqis=iqis)
...@@ -350,7 +358,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -350,7 +358,7 @@ class DeepseekV2MoE(nn.Module):
if self.enable_shared_experts_overlap: if self.enable_shared_experts_overlap:
assert self.shared_experts is not None assert self.shared_experts is not None
shared_output, final_hidden_states = shared_exprts_overlap_pass(hidden_states, router_logits) shared_output, final_hidden_states = shared_exprts_overlap_pass(hidden_states, router_logits, iqis = iqis)
# Fix FP16 overflow # Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details. # See DeepseekV2DecoderLayer for more details.
if hidden_states.dtype != torch.float16: if hidden_states.dtype != torch.float16:
...@@ -754,8 +762,6 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -754,8 +762,6 @@ class DeepseekV2MLAAttention(nn.Module):
self.prefix = prefix self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2]) self.debug_layer_idx = int(self.prefix.split(".")[-2])
# TODO wjl: 这里的forward拆了
def forward( def forward(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
...@@ -772,9 +778,6 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -772,9 +778,6 @@ class DeepseekV2MLAAttention(nn.Module):
if envs.USE_FUSED_RMS_QUANT and iqis is not None: if envs.USE_FUSED_RMS_QUANT and iqis is not None:
if envs.VLLM_USE_FUSED_QA_KVA_GEMM: if envs.VLLM_USE_FUSED_QA_KVA_GEMM:
if self.q_lora_rank is not None: if self.q_lora_rank is not None:
# rms_weight=rms_weight, residual=residual, update_hd=False
qc_kvc_kpe, _bias = self.qa_kva_proj(hidden_states, iqis) qc_kvc_kpe, _bias = self.qa_kva_proj(hidden_states, iqis)
q_c = qc_kvc_kpe[:, :self.q_lora_rank] q_c = qc_kvc_kpe[:, :self.q_lora_rank]
kvc_kpe = qc_kvc_kpe[:, self.q_lora_rank:] kvc_kpe = qc_kvc_kpe[:, self.q_lora_rank:]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment