Unverified Commit 0d658ac3 authored by Cheng Wan's avatar Cheng Wan Committed by GitHub
Browse files

Support recording experts workload in QWen2-MoE (#4775)

parent ced35a06
...@@ -44,10 +44,13 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ...@@ -44,10 +44,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.managers.utils import ExpertDistributionRecorder
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix from sglang.srt.utils import add_prefix
expert_distribution_recorder = ExpertDistributionRecorder()
class Qwen2MoeMLP(nn.Module): class Qwen2MoeMLP(nn.Module):
def __init__( def __init__(
...@@ -366,6 +369,7 @@ class Qwen2MoeModel(nn.Module): ...@@ -366,6 +369,7 @@ class Qwen2MoeModel(nn.Module):
hidden_states = input_embeds hidden_states = input_embeds
residual = None residual = None
for i in range(len(self.layers)): for i in range(len(self.layers)):
expert_distribution_recorder.set_current_layer(i)
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, hidden_states, forward_batch, residual positions, hidden_states, forward_batch, residual
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment