Unverified Commit 9e2f7252 authored by Yi Zhang, committed by GitHub

add dual stream for qwen2_moe (#10252)

parent 21176b00
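The change splits Qwen2MoeSparseMoeBlock.forward into _forward_shared_experts and _forward_router_experts and, for small batches on CUDA, runs the two halves on separate streams so their kernels can overlap. Below is a minimal, self-contained sketch of that overlap pattern using only stock PyTorch; the helper names and toy modules are illustrative and not taken from the diff.

import torch

def dual_stream_forward(shared_fn, routed_fn, hidden_states, alt_stream):
    # Run shared_fn on the current stream and routed_fn on alt_stream, then re-join.
    current_stream = torch.cuda.current_stream()
    # alt_stream must first see all prior work on the current stream
    # (e.g. the attention output that produced hidden_states).
    alt_stream.wait_stream(current_stream)
    shared_output = shared_fn(hidden_states)       # current stream
    with torch.cuda.stream(alt_stream):
        routed_output = routed_fn(hidden_states)   # alternate stream
    # Later ops on the current stream must wait for alt_stream to finish.
    current_stream.wait_stream(alt_stream)
    return routed_output, shared_output

if torch.cuda.is_available():
    x = torch.randn(256, 1024, device="cuda")
    shared = torch.nn.Linear(1024, 1024, device="cuda")
    routed = torch.nn.Sequential(
        torch.nn.Linear(1024, 4096, device="cuda"),
        torch.nn.GELU(),
        torch.nn.Linear(4096, 1024, device="cuda"),
    )
    alt_stream = torch.cuda.Stream()
    routed_out, shared_out = dual_stream_forward(shared, routed, x, alt_stream)
    out = routed_out + shared_out  # mirrors final_hidden_states + shared_output
    torch.cuda.synchronize()

The two wait_stream calls are what make the fork/join safe: the fork waits until hidden_states is ready before the alternate stream starts, and the join keeps the final addition on the default stream from racing ahead of the routed experts.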
@@ -65,10 +65,12 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
-from sglang.srt.utils import add_prefix, make_layers
+from sglang.srt.utils import add_prefix, is_cuda, make_layers

 logger = logging.getLogger(__name__)

+_is_cuda = is_cuda()
+

 class Qwen2MoeMLP(nn.Module):
     def __init__(
@@ -122,11 +124,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         layer_id: int,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        alt_stream: Optional[torch.cuda.Stream] = None,
         prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.layer_id = layer_id
+        self.alt_stream = alt_stream
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
@@ -168,14 +172,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             self.shared_expert = None
         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)

-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        forward_batch: Optional[ForwardBatch] = None,
-        use_reduce_scatter: bool = False,
-    ) -> torch.Tensor:
-        num_tokens, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
+    def _forward_shared_experts(self, hidden_states: torch.Tensor):
         shared_output = None
         if self.shared_expert is not None:
             shared_output = self.shared_expert(hidden_states)
@@ -183,11 +180,51 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                 shared_output = (
                     F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output
                 )
+        return shared_output

+    def _forward_router_experts(self, hidden_states: torch.Tensor):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         topk_output = self.topk(hidden_states, router_logits)
-        final_hidden_states = self.experts(hidden_states, topk_output)
+        return self.experts(hidden_states, topk_output)
+
+    def forward_normal_dual_stream(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        current_stream = torch.cuda.current_stream()
+        self.alt_stream.wait_stream(current_stream)
+        shared_output = self._forward_shared_experts(hidden_states)
+        with torch.cuda.stream(self.alt_stream):
+            router_output = self._forward_router_experts(hidden_states)
+        current_stream.wait_stream(self.alt_stream)
+        return router_output, shared_output
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        use_reduce_scatter: bool = False,
+    ) -> torch.Tensor:
+        num_tokens, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        DUAL_STREAM_TOKEN_THRESHOLD = 1024
+        if (
+            self.alt_stream is not None
+            and hidden_states.shape[0] > 0
+            and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+        ):
+            final_hidden_states, shared_output = self.forward_normal_dual_stream(
+                hidden_states
+            )
+        else:
+            shared_output = self._forward_shared_experts(hidden_states)
+            final_hidden_states = self._forward_router_experts(hidden_states)
+
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1 and not use_reduce_scatter:
@@ -346,6 +383,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
                 layer_id=layer_id,
                 config=config,
                 quant_config=quant_config,
+                alt_stream=alt_stream,
                 prefix=add_prefix("mlp", prefix),
             )
         else:
@@ -528,8 +566,12 @@ class Qwen2MoeForCausalLM(nn.Module):
         self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
+        alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.model = Qwen2MoeModel(
-            config, quant_config, prefix=add_prefix("model", prefix)
+            config,
+            quant_config,
+            prefix=add_prefix("model", prefix),
+            alt_stream=alt_stream,
         )
         self.lm_head = ParallelLMHead(
             config.vocab_size,
......
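The dual-stream path above is only taken when the batch holds at most DUAL_STREAM_TOKEN_THRESHOLD = 1024 tokens and an alt_stream exists; for larger batches the expert kernels tend to fill the GPU on their own, so the overlap mainly targets the small-batch regime. A rough, standalone way to check whether two independent branches actually overlap is to time both schedules with CUDA events. The sketch below uses arbitrary toy shapes and is not part of the commit.

import torch

if torch.cuda.is_available():
    a = torch.randn(512, 2048, device="cuda")
    w_shared = torch.randn(2048, 2048, device="cuda")
    w_routed = torch.randn(2048, 2048, device="cuda")
    alt_stream = torch.cuda.Stream()

    def run(dual: bool) -> float:
        # Time 50 iterations of two independent matmuls, either back-to-back
        # on one stream or forked across two streams as in the diff.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start.record()
        for _ in range(50):
            if dual:
                cur = torch.cuda.current_stream()
                alt_stream.wait_stream(cur)
                a @ w_shared                       # current stream
                with torch.cuda.stream(alt_stream):
                    a @ w_routed                   # alternate stream
                cur.wait_stream(alt_stream)
            else:
                a @ w_shared
                a @ w_routed
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end)  # milliseconds

    print(f"sequential: {run(False):.2f} ms, dual-stream: {run(True):.2f} ms")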