"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "87cf88ed3d351cc4e2b7cdd462ddf7a4ebf2109e"
Unverified commit 27778010, authored by Yi Zhang, committed by GitHub

fix dual stream bug (#10352)

parent 46d8fb1c
@@ -62,6 +62,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
@@ -194,7 +195,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
     ) -> torch.Tensor:
         current_stream = torch.cuda.current_stream()
         self.alt_stream.wait_stream(current_stream)
-        shared_output = self._forward_shared_experts(hidden_states)
+        shared_output = self._forward_shared_experts(hidden_states.clone())
         with torch.cuda.stream(self.alt_stream):
             router_output = self._forward_router_experts(hidden_states)
@@ -217,6 +218,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             self.alt_stream is not None
             and hidden_states.shape[0] > 0
             and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+            and get_is_capture_mode()
         ):
             final_hidden_states, shared_output = self.forward_normal_dual_stream(
                 hidden_states
...
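Context for the change (a hedged reading of the diff, not an authoritative description of sglang internals): in the dual-stream path the shared experts run on the current CUDA stream while the routed experts run on self.alt_stream. Cloning hidden_states before the shared-expert call gives that branch its own buffer, so kernels launched on the side stream cannot race with it, and the added get_is_capture_mode() check limits the dual-stream path to CUDA graph capture. The sketch below only illustrates the general two-stream overlap pattern under those assumptions; the helper names (dual_stream_moe, shared_experts, router_experts) are placeholders, not sglang APIs.

import torch

def dual_stream_moe(hidden_states, shared_experts, router_experts, alt_stream):
    # The side stream must not start until all work already queued on the
    # default stream (e.g. the preceding layernorm) is ordered before it.
    current_stream = torch.cuda.current_stream()
    alt_stream.wait_stream(current_stream)

    # Shared experts run on the default stream. Passing a clone gives this
    # branch its own copy of the activations, so concurrent kernels on the
    # side stream that read or overwrite hidden_states cannot race with it.
    shared_output = shared_experts(hidden_states.clone())

    # Routed experts run concurrently on the side stream.
    with torch.cuda.stream(alt_stream):
        router_output = router_experts(hidden_states)

    # Rejoin before the default stream consumes the side stream's output.
    current_stream.wait_stream(alt_stream)
    return router_output + shared_output

A caller would create the side stream once (for example, alt_stream = torch.cuda.Stream()) and reuse it across layers, as Qwen2MoeSparseMoeBlock does with self.alt_stream; allocating a new stream per forward pass would defeat the overlap.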