Unverified Commit 3b2680a4 authored by fzyzcjy, committed by GitHub

Overlap shared expert and routed expert computations (#5121)

parent 79961afa
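This change issues the shared expert and the routed experts on two different streams so their kernels can execute concurrently instead of back to back. As a rough standalone illustration of that pattern (a minimal sketch, not part of the commit: the toy GEMMs, the shapes, and the name `overlap_on_two_streams` are made up), assuming a CUDA device:

```python
# Minimal sketch of the dual-stream overlap pattern; two toy GEMMs stand in for
# the shared expert and the router + routed experts.
import torch


def overlap_on_two_streams(x, w_shared, w_routed, alt_stream):
    # The alternate stream must first wait for whatever produced x on the
    # current (default) stream, e.g. the preceding attention block.
    alt_stream.wait_stream(torch.cuda.current_stream())

    # Shared-expert work stays on the default stream...
    shared_out = x @ w_shared

    # ...while the routed-expert work is issued on the alternate stream, so the
    # two GEMMs can run concurrently when neither saturates the GPU.
    with torch.cuda.stream(alt_stream):
        routed_out = x @ w_routed

    # Rejoin: the default stream must not read routed_out before it is ready.
    torch.cuda.current_stream().wait_stream(alt_stream)
    return shared_out + routed_out


if torch.cuda.is_available():
    x = torch.randn(2, 1024, device="cuda")  # decode-sized token batch
    w_shared = torch.randn(1024, 1024, device="cuda")
    w_routed = torch.randn(1024, 1024, device="cuda")
    out = overlap_on_two_streams(x, w_shared, w_routed, torch.cuda.Stream())
    torch.cuda.synchronize()
    print(out.shape)  # torch.Size([2, 1024])
```

The same `wait_stream` handshake appears in `_forward_core_shared_routed_overlap` below: the shared expert stays on the default stream, the router and routed experts run on a lazily created alternate stream, and the default stream waits on the alternate stream before the two outputs are summed and all-reduced.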
@@ -90,7 +90,7 @@ class LlamaMLP(nn.Module):
         )
         self.act_fn = SiluAndMul()

-    def forward(self, x):
+    def forward(self, x, forward_batch=None):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
...
@@ -46,7 +46,11 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
 from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
@@ -81,6 +85,7 @@ class Llama4MoE(nn.Module):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.top_k = config.num_experts_per_tok
+        self.device_module = torch.get_device_module()
         intermediate_size_moe = config.intermediate_size
         self.router = ReplicatedLinear(
@@ -113,7 +118,25 @@ class Llama4MoE(nn.Module):
             reduce_results=False,  # We need to do scatter before reduce
         )

-    def forward(self, hidden_states):
+    def forward(self, hidden_states, forward_batch: ForwardBatch):
+        shared_out, routed_out = self._forward_core(
+            hidden_states, forward_batch.forward_mode
+        )
+        out_aD = routed_out + shared_out
+        if self.tp_size > 1:
+            out_aD = tensor_model_parallel_all_reduce(out_aD)
+        return out_aD
+
+    def _forward_core(self, hidden_states, forward_mode: ForwardMode):
+        if hidden_states.shape[0] < 4:
+            return self._forward_core_shared_routed_overlap(hidden_states)
+        else:
+            return self._forward_core_normal(hidden_states)
+
+    def _forward_core_normal(self, hidden_states):
         # router_scores: [num_tokens, num_experts]
         router_logits, _ = self.router(hidden_states)
         shared_out = self.shared_expert(hidden_states)
@@ -121,12 +144,35 @@ class Llama4MoE(nn.Module):
             hidden_states=hidden_states,
             router_logits=router_logits,
         )
-        out_aD = routed_out + shared_out
-        if self.tp_size > 1:
-            out_aD = tensor_model_parallel_all_reduce(out_aD)
-        return out_aD
+        return shared_out, routed_out
+
+    def _forward_core_shared_routed_overlap(self, hidden_states):
+        alt_stream = _get_or_create_alt_stream(self.device_module)
+
+        alt_stream.wait_stream(self.device_module.current_stream())
+
+        shared_out = self.shared_expert(hidden_states)
+
+        with self.device_module.stream(alt_stream):
+            # router_scores: [num_tokens, num_experts]
+            router_logits, _ = self.router(hidden_states)
+            routed_out = self.experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+            )
+        self.device_module.current_stream().wait_stream(alt_stream)
+
+        return shared_out, routed_out
+
+
+_alt_stream = None
+
+
+def _get_or_create_alt_stream(device_module):
+    global _alt_stream
+    if _alt_stream is None:
+        _alt_stream = device_module.Stream()
+    return _alt_stream


 class Llama4Attention(nn.Module):
@@ -380,7 +426,7 @@ class Llama4DecoderLayer(nn.Module):
         )

         # Fully Connected
-        hidden_states = self.feed_forward(hidden_states)
+        hidden_states = self.feed_forward(hidden_states, forward_batch)

         # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
         # Scatter
...
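Two implementation notes on the diff above. First, `_forward_core` only takes the dual-stream path when the batch holds fewer than 4 tokens (decode-sized batches), presumably because that is where neither the shared-expert GEMM nor the routed-expert kernels keep the GPU busy on their own; larger batches stay on the single-stream path. Second, the alternate stream is a lazily created, process-wide singleton, so every `Llama4MoE` layer reuses one extra stream rather than allocating its own. A hedged sketch of that second point (assuming `torch.get_device_module()` resolves to `torch.cuda` on the running machine, as it does on a CUDA build):

```python
# Sketch of the lazy, shared alternate stream used by the diff above.
import torch

_alt_stream = None  # created on first use, then shared by every layer


def _get_or_create_alt_stream(device_module):
    global _alt_stream
    if _alt_stream is None:
        _alt_stream = device_module.Stream()
    return _alt_stream


if torch.cuda.is_available():
    device_module = torch.get_device_module()  # torch.cuda on a CUDA build
    s1 = _get_or_create_alt_stream(device_module)
    s2 = _get_or_create_alt_stream(device_module)
    assert s1 is s2  # one stream, created once, reused by all callers
```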