Unverified Commit 587b4c6e authored by yilian49's avatar yilian49 Committed by GitHub
Browse files

EPLB support for MTP (#7510)

parent 7b9a174a
...@@ -61,6 +61,10 @@ class ExpertDistributionRecorder(ABC): ...@@ -61,6 +61,10 @@ class ExpertDistributionRecorder(ABC):
def with_debug_name(self, debug_name): def with_debug_name(self, debug_name):
yield yield
@contextmanager
def disable_this_region(self):
    """No-op default context manager; concrete recorders may override it to
    temporarily suppress recording for the duration of the with-block."""
    yield
@contextmanager @contextmanager
def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch): def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
yield yield
...@@ -116,6 +120,7 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder): ...@@ -116,6 +120,7 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
self._expert_location_metadata = expert_location_metadata self._expert_location_metadata = expert_location_metadata
self._recording = False self._recording = False
self._disable_all = False
self._current_forward_pass_id = Withable() self._current_forward_pass_id = Withable()
self._current_layer_idx = Withable() self._current_layer_idx = Withable()
self._current_debug_name = Withable() self._current_debug_name = Withable()
...@@ -148,6 +153,16 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder): ...@@ -148,6 +153,16 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
finally: finally:
self._on_forward_pass_end(forward_pass_id) self._on_forward_pass_end(forward_pass_id)
@contextmanager
def disable_this_region(self):
    """Temporarily turn off all recording while the with-block executes.

    Saves the current ``_disable_all`` flag, forces it to ``True`` for the
    duration of the block, and restores the saved value on exit — even when
    the body raises — so nested uses compose correctly.
    """
    saved_flag = self._disable_all
    self._disable_all = True
    try:
        yield
    finally:
        # Restore (not reset to False) so an outer disabled region stays disabled.
        self._disable_all = saved_flag
def _on_forward_pass_start(self, forward_batch: ForwardBatch): def _on_forward_pass_start(self, forward_batch: ForwardBatch):
if not self._recording: if not self._recording:
return return
...@@ -189,6 +204,8 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder): ...@@ -189,6 +204,8 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
) )
def _on_hook(self, hook_name: str, **kwargs): def _on_hook(self, hook_name: str, **kwargs):
if self._disable_all:
return
if not (self._recording or torch.cuda.is_current_stream_capturing()): if not (self._recording or torch.cuda.is_current_stream_capturing()):
return return
gatherer = self._single_pass_gatherers[ gatherer = self._single_pass_gatherers[
...@@ -462,6 +479,10 @@ class _SelectExpertsSinglePassGatherer(_LayerBasedGpuSinglePassGatherer): ...@@ -462,6 +479,10 @@ class _SelectExpertsSinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
    """Accumulate per-expert selection counts for one layer.

    Args:
        layer_idx: Row of ``self._data`` (the per-layer count buffer) to update.
        topk_ids: Tensor of selected expert ids; entries equal to -1 are
            treated as invalid and ignored.
    """
    topk_ids = topk_ids.flatten()
    mask = topk_ids != -1
    # The flattened id tensor must line up element-for-element with this
    # layer's count row; MTP batches that emit extra predicted tokens break
    # this invariant, hence the explicit guard.
    assert self._data[layer_idx, :].shape == topk_ids.shape, (
        "Shape mismatch between data and topk_ids. "
        "Selecting expert is not supported for multiple token prediction at the moment."
    )
    # Invalid (-1) slots are redirected to index 0 with a zero source value,
    # so they contribute nothing to the counts.
    self._data[layer_idx, :].scatter_add_(
        dim=0, index=topk_ids.masked_fill(~mask, 0).long(), src=mask.int()
    )
......
...@@ -28,6 +28,9 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ...@@ -28,6 +28,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.managers.expert_distribution import (
get_global_expert_distribution_recorder,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
...@@ -82,7 +85,6 @@ class DeepseekModelNextN(nn.Module): ...@@ -82,7 +85,6 @@ class DeepseekModelNextN(nn.Module):
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
zero_allocator = BumpAllocator( zero_allocator = BumpAllocator(
buffer_size=2, buffer_size=2,
dtype=torch.float32, dtype=torch.float32,
...@@ -108,9 +110,10 @@ class DeepseekModelNextN(nn.Module): ...@@ -108,9 +110,10 @@ class DeepseekModelNextN(nn.Module):
) )
residual = None residual = None
hidden_states, residual = self.decoder( with get_global_expert_distribution_recorder().disable_this_region():
positions, hidden_states, forward_batch, residual, zero_allocator hidden_states, residual = self.decoder(
) positions, hidden_states, forward_batch, residual, zero_allocator
)
if not forward_batch.forward_mode.is_idle(): if not forward_batch.forward_mode.is_idle():
if residual is not None: if residual is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment