"git@developer.sourcefind.cn:change/sglang.git" did not exist on "f55933e1cc50263e0bcde65c3b78969b56225c7f"
Unverified Commit 24dc2bee authored by Yuan Luo's avatar Yuan Luo Committed by GitHub
Browse files

Fix Bailing MoE model bugs (#10362)


Co-authored-by: default avatarluoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: default avatar羽癫 <yudian.zy@antgroup.com>
parent fac07c9b
......@@ -128,7 +128,9 @@ class BailingMoEMLP(nn.Module):
gate_up, _ = self.gate_up_proj(hidden_states)
hidden_states = self.act_fn(gate_up)
hidden_states, _ = self.down_proj(hidden_states)
hidden_states, _ = self.down_proj(
hidden_states, skip_all_reduce=use_reduce_scatter
)
return hidden_states
......@@ -328,7 +330,7 @@ class BailingMoESparseMoeBlock(nn.Module):
) -> torch.Tensor:
current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
shared_output = self._forward_shared_experts(hidden_states)
shared_output = self._forward_shared_experts(hidden_states.clone())
with torch.cuda.stream(self.alt_stream):
router_output = self._forward_router_experts(hidden_states)
......@@ -347,8 +349,9 @@ class BailingMoESparseMoeBlock(nn.Module):
DUAL_STREAM_TOKEN_THRESHOLD = 1024
if (
self.alt_stream is not None
and num_tokens > 0
and num_tokens <= DUAL_STREAM_TOKEN_THRESHOLD
and hidden_states.shape[0] > 0
and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
and get_is_capture_mode()
):
final_hidden_states, shared_output = self.forward_normal_dual_stream(
hidden_states
......
......@@ -757,7 +757,7 @@ class ServerArgs:
if model_arch in [
"DeepseekV3ForCausalLM",
"Glm4MoeForCausalLM",
"BailingMoeV2ForCausalLM",
"BailingMoeForCausalLM",
"BailingMoeV2ForCausalLM",
]:
# Auto set draft_model_path DeepSeek-V3/R1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment