change / sglang / Commits / f0653886

Unverified commit f0653886, authored May 20, 2025 by fzyzcjy, committed by GitHub May 19, 2025
Parent: b1465557

Expert distribution recording without overhead for EPLB (#4957)

Showing 12 changed files with 1119 additions and 190 deletions (+1119, -190)
Files changed:

docs/backend/native_api.ipynb                              +2    -14
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py    +14   -0
python/sglang/srt/layers/moe/topk.py                       +5    -4
python/sglang/srt/managers/expert_distribution.py          +595  -56
python/sglang/srt/managers/expert_location.py (new file)   +273  -0
python/sglang/srt/managers/scheduler.py                    +7    -6
python/sglang/srt/model_executor/model_runner.py           +47   -0
python/sglang/srt/models/deepseek_v2.py                    +20   -9
python/sglang/srt/models/qwen2_moe.py                      +18   -8
python/sglang/srt/server_args.py                           +32   -0
python/sglang/srt/utils.py                                 +35   -1
test/srt/test_expert_distribution.py                       +71   -92
docs/backend/native_api.ipynb  (+2, -14)

@@ -390,7 +390,7 @@
    "outputs": [],
    "source": [
     "expert_record_server_process, port = launch_server_cmd(\n",
-    "    \"python3 -m sglang.launch_server --model-path Qwen/Qwen1.5-MoE-A2.7B --host 0.0.0.0\"\n",
+    "    \"python3 -m sglang.launch_server --model-path Qwen/Qwen1.5-MoE-A2.7B --host 0.0.0.0 --expert-distribution-recorder-mode stat\"\n",
    ")\n",
    "\n",
    "wait_for_server(f\"http://localhost:{port}\")"

@@ -415,19 +415,7 @@
    "print_highlight(response)\n",
    "\n",
    "response = requests.post(f\"http://localhost:{port}/dump_expert_distribution_record\")\n",
-    "print_highlight(response)\n",
-    "\n",
-    "import glob\n",
-    "\n",
-    "output_file = glob.glob(\"expert_distribution_*.csv\")[0]\n",
-    "with open(output_file, \"r\") as f:\n",
-    "    print_highlight(\"\\n| Layer ID | Expert ID | Count |\")\n",
-    "    print_highlight(\"|----------|-----------|--------|\")\n",
-    "    next(f)\n",
-    "    for i, line in enumerate(f):\n",
-    "        if i < 9:\n",
-    "            layer_id, expert_id, count = line.strip().split(\",\")\n",
-    "            print_highlight(f\"| {layer_id:8} | {expert_id:9} | {count:6} |\")"
+    "print_highlight(response)"
   ]
  },
  {
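For reference, the recorder endpoints touched by this notebook can be exercised end to end as below. This is a minimal sketch: the endpoint paths come from the diff itself, while the assumption is that a server launched with --expert-distribution-recorder-mode stat is already listening on `port`.

import requests

base = f"http://localhost:{port}"  # `port` comes from launch_server_cmd above

requests.post(f"{base}/start_expert_distribution_record")
requests.post(
    f"{base}/generate",
    json={"text": "The capital of France is", "sampling_params": {"max_new_tokens": 32}},
)
requests.post(f"{base}/stop_expert_distribution_record")
resp = requests.post(f"{base}/dump_expert_distribution_record")
print(resp.status_code)  # expect 200 on success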
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py  (+14, -0)

 import logging

 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
+from sglang.srt.managers.expert_distribution import (
+    get_global_expert_distribution_recorder,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import DeepEPMode, load_json_config

@@ -326,6 +329,13 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             config=_DeepEPConfig.get_instance().normal_dispatch_config,
         )
+        get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
+            num_recv_tokens_per_expert_list,
+            num_tokens_per_rank=num_tokens_per_rank,
+            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
+            num_tokens_per_expert=num_tokens_per_expert,
+        )
         return (
             recv_x,
             recv_topk_idx,

@@ -489,6 +499,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         ):
             hook() if self.return_recv_hook else event.current_stream_wait()
+        get_global_expert_distribution_recorder().on_deepep_dispatch_low_latency(
+            masked_m
+        )
         reorder_topk_ids = seg_indptr = None
         return (
python/sglang/srt/layers/moe/topk.py  (+5, -4)

@@ -18,7 +18,10 @@ from typing import Callable, Optional
 import torch
 import torch.nn.functional as F

-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip

@@ -31,8 +34,6 @@ if _is_cuda:
 if _is_cuda or _is_hip:
     from sgl_kernel import topk_softmax

-expert_distribution_recorder = ExpertDistributionRecorder()
-

 def fused_topk_native(
     hidden_states: torch.Tensor,

@@ -353,6 +354,6 @@ def select_experts(
         renormalize=renormalize,
     )

-    expert_distribution_recorder.record_new_token(topk_ids)
+    get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)

     return topk_weights, topk_ids
python/sglang/srt/managers/expert_distribution.py  (+595, -56)

(This diff is collapsed on the original page and is not reproduced here.)
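Although the module itself is collapsed above, its public surface can be inferred from the call sites elsewhere in this commit. The following is a minimal sketch of a recorder with that interface, not the actual implementation: only the method names are taken from the diff; the no-op bodies are an assumption illustrating how "recording without overhead" is plausibly achieved when the recorder is disabled.

from contextlib import contextmanager


class _NoopExpertDistributionRecorderSketch:
    """Sketch only: every hook degenerates to a no-op when recording is off."""

    def start_record(self): ...
    def stop_record(self): ...
    def dump_record(self): ...

    # Hooks seen at call sites in this commit:
    def on_select_experts(self, topk_ids): ...
    def on_deepep_dispatch_normal(self, num_recv_tokens_per_expert_list, **kwargs): ...
    def on_deepep_dispatch_low_latency(self, masked_m): ...

    @contextmanager
    def with_current_layer(self, layer_idx):
        yield  # a real recorder would attribute work inside the block to layer_idx

    @contextmanager
    def with_forward_pass(self, forward_pass_id, forward_batch):
        yield  # likewise for the enclosing forward pass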
python/sglang/srt/managers/expert_location.py  (new file, mode 100644; +273, -0)

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import torch
import torch.distributed
import torch.nn.functional as F

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.model_loader import get_model_architecture
from sglang.srt.server_args import ServerArgs

logger = logging.getLogger(__name__)


@dataclass
class ExpertLocationMetadata:
    physical_to_logical_map: torch.Tensor  # (layers, num_physical_experts)
    logical_to_all_physical_map: torch.Tensor  # (layers, num_logical_experts, X)
    logical_to_all_physical_map_num_valid: torch.Tensor  # (layers, num_logical_experts)

    # -------------------------------- properties ------------------------------------

    @property
    def num_layers(self) -> int:
        return self.physical_to_logical_map.shape[0]

    @property
    def num_physical_experts(self) -> int:
        return self.physical_to_logical_map.shape[1]

    @property
    def num_local_physical_experts(self) -> int:
        ans, remainder = divmod(self.num_physical_experts, self.ep_size)
        assert remainder == 0
        return ans

    @property
    def num_logical_experts(self) -> int:
        return self.logical_to_all_physical_map.shape[1]

    @property
    def ep_size(self):
        # TODO change when EP size != world size
        return torch.distributed.get_world_size()

    def __post_init__(self):
        num_layers_0, num_physical_experts_0 = self.physical_to_logical_map.shape
        num_layers_1, num_logical_experts_0, num_physical_experts_1 = (
            self.logical_to_all_physical_map.shape
        )
        num_layers_2, num_logical_experts_1 = (
            self.logical_to_all_physical_map_num_valid.shape
        )
        # TODO pr-chain: enable this later
        # assert num_layers_0 == num_layers_1 == num_layers_2 == num_layers_3
        # assert num_logical_experts_0 == num_logical_experts_1 == num_logical_experts_2
        assert num_physical_experts_0 == num_physical_experts_1

    # -------------------------------- construction ------------------------------------

    @staticmethod
    def init_trivial(server_args: ServerArgs, model_config: ModelConfig):
        """Trivial location - logical expert i corresponds to physical expert i"""
        common = ExpertLocationMetadata._init_common(server_args, model_config)
        num_physical_experts = common["num_physical_experts"]
        model_config_for_expert_location = common["model_config_for_expert_location"]
        num_layers = model_config_for_expert_location.num_layers
        num_logical_experts = model_config_for_expert_location.num_logical_experts

        physical_to_logical_map = (
            torch.arange(0, num_physical_experts).repeat(num_layers, 1)
            % num_logical_experts
        )

        return ExpertLocationMetadata.init_by_mapping(
            server_args,
            model_config,
            physical_to_logical_map=physical_to_logical_map,
        )

    @staticmethod
    def init_by_mapping(
        server_args: ServerArgs,
        model_config: ModelConfig,
        physical_to_logical_map,
    ):
        if not isinstance(physical_to_logical_map, torch.Tensor):
            physical_to_logical_map = torch.tensor(physical_to_logical_map)
        physical_to_logical_map = physical_to_logical_map.to(server_args.device)

        common = ExpertLocationMetadata._init_common(server_args, model_config)
        model_config_for_expert_location = common["model_config_for_expert_location"]
        logical_to_all_physical_map = _compute_logical_to_all_physical_map(
            physical_to_logical_map,
            num_logical_experts=model_config_for_expert_location.num_logical_experts,
        )

        return ExpertLocationMetadata._init_raw(
            ep_size=common["ep_size"],
            physical_to_logical_map=physical_to_logical_map,
            logical_to_all_physical_map=logical_to_all_physical_map,
        )

    @staticmethod
    def _init_common(server_args: ServerArgs, model_config: ModelConfig):
        model_config_for_expert_location = (
            ModelConfigForExpertLocation.from_model_config(model_config)
        )

        num_physical_experts = (
            model_config_for_expert_location.num_logical_experts
            # TODO pr-chain: enable this later
            # + server_args.ep_num_redundant_experts
        )
        ep_size = server_args.ep_size
        assert num_physical_experts % ep_size == 0
        num_local_physical_experts = num_physical_experts // ep_size

        return dict(
            model_config_for_expert_location=model_config_for_expert_location,
            num_physical_experts=num_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
            ep_size=ep_size,
        )

    @staticmethod
    def _init_raw(
        ep_size: int,
        physical_to_logical_map: torch.Tensor,
        logical_to_all_physical_map: torch.Tensor,
    ):
        _, num_physical_experts = physical_to_logical_map.shape

        logical_to_all_physical_map_padded = F.pad(
            logical_to_all_physical_map,
            (0, num_physical_experts - logical_to_all_physical_map.shape[-1]),
            value=-1,
        )

        logical_to_all_physical_map_num_valid = torch.count_nonzero(
            logical_to_all_physical_map != -1, dim=-1
        )

        return ExpertLocationMetadata(
            physical_to_logical_map=physical_to_logical_map,
            logical_to_all_physical_map=logical_to_all_physical_map_padded,
            logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
        )


_global_expert_location_metadata: Optional[ExpertLocationMetadata] = None


def get_global_expert_location_metadata():
    return _global_expert_location_metadata


def set_global_expert_location_metadata(value):
    global _global_expert_location_metadata
    assert _global_expert_location_metadata is None
    _global_expert_location_metadata = value


def _compute_logical_to_all_physical_map(
    physical_to_logical_map: torch.Tensor, num_logical_experts: int
):
    # This is rarely called, so we use for loops for maximum clarity

    num_layers, num_physical_experts = physical_to_logical_map.shape

    logical_to_all_physical_map = [
        [[] for _ in range(num_logical_experts)] for _ in range(num_layers)
    ]
    for layer_id in range(num_layers):
        for physical_expert_id in range(num_physical_experts):
            logical_expert_id = physical_to_logical_map[
                layer_id, physical_expert_id
            ].item()
            logical_to_all_physical_map[layer_id][logical_expert_id].append(
                physical_expert_id
            )

    logical_to_all_physical_map = _pad_nested_array(
        logical_to_all_physical_map, pad_value=-1
    )

    return torch.tensor(
        logical_to_all_physical_map, device=physical_to_logical_map.device
    )


def _pad_nested_array(arr, pad_value):
    max_len = max(len(inner) for outer in arr for inner in outer)
    padded = [
        [inner + [pad_value] * (max_len - len(inner)) for inner in outer]
        for outer in arr
    ]
    return padded


@dataclass
class ModelConfigForExpertLocation:
    num_layers: int
    num_logical_experts: int
    num_groups: Optional[int] = None

    @staticmethod
    def init_dummy():
        return ModelConfigForExpertLocation(num_layers=1, num_logical_experts=1)

    @staticmethod
    def from_model_config(model_config: ModelConfig):
        model_class, _ = get_model_architecture(model_config)
        if hasattr(model_class, "get_model_config_for_expert_location"):
            return model_class.get_model_config_for_expert_location(
                model_config.hf_config
            )
        else:
            return ModelConfigForExpertLocation.init_dummy()


def compute_initial_expert_location_metadata(
    server_args: ServerArgs, model_config: ModelConfig
) -> ExpertLocationMetadata:
    data = server_args.init_expert_location
    if data == "trivial":
        logger.info("init_expert_location from trivial")
        return ExpertLocationMetadata.init_trivial(server_args, model_config)

    # TODO unify with the utils function
    if data.endswith(".pt"):
        data_dict = torch.load(data, weights_only=True)
    elif data.endswith(".json"):
        data_dict = json.loads(Path(data).read_text())
    else:
        data_dict = json.loads(data)

    if "physical_to_logical_map" in data_dict:
        logger.info(
            "init_expert_location from init_by_mapping using ServerArgs.init_expert_location"
        )
        return ExpertLocationMetadata.init_by_mapping(
            server_args, model_config, **data_dict
        )
    elif "logical_count" in data_dict:
        # TODO pr-chain: enable this later
        raise NotImplementedError
        # logger.info(
        #     "init_expert_location from init_by_eplb using ServerArgs.init_expert_location"
        # )
        # return ExpertLocationMetadata.init_by_eplb(
        #     server_args, model_config, logical_count=data_dict["logical_count"]
        # )
    else:
        raise NotImplementedError(
            f"Unknown init_expert_location format ({list(data_dict.keys())=})"
        )
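To make the data layout above concrete, here is a small worked example of the logical-to-physical mapping. It is a standalone sketch that inlines the same padding and counting rules rather than importing the module.

import torch

# 1 layer, 4 physical experts, 3 logical experts:
# physical experts 0..3 serve logical experts 0, 1, 2, 0 respectively.
physical_to_logical_map = torch.tensor([[0, 1, 2, 0]])

# logical_to_all_physical_map, padded with -1 (the _pad_nested_array rule):
# logical 0 -> physical [0, 3]; logical 1 -> [1]; logical 2 -> [2]
logical_to_all_physical = torch.tensor([[[0, 3], [1, -1], [2, -1]]])

# num_valid counts the non-(-1) entries per logical expert, as in _init_raw.
num_valid = torch.count_nonzero(logical_to_all_physical != -1, dim=-1)
print(num_valid)  # tensor([[2, 1, 1]])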
python/sglang/srt/managers/scheduler.py  (+7, -6)

@@ -59,7 +59,10 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+)
 from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,

@@ -142,8 +145,6 @@ from sglang.srt.utils import (
 )
 from sglang.utils import TypeBasedDispatcher, get_exception_traceback

-expert_distribution_recorder = ExpertDistributionRecorder()
-
 logger = logging.getLogger(__name__)

 # Test retract decode for debugging purposes

@@ -2162,11 +2163,11 @@ class Scheduler(
     def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
         if recv_req == ExpertDistributionReq.START_RECORD:
-            expert_distribution_recorder.start_record()
+            get_global_expert_distribution_recorder().start_record()
         elif recv_req == ExpertDistributionReq.STOP_RECORD:
-            expert_distribution_recorder.stop_record()
+            get_global_expert_distribution_recorder().stop_record()
         elif recv_req == ExpertDistributionReq.DUMP_RECORD:
-            expert_distribution_recorder.dump_record()
+            get_global_expert_distribution_recorder().dump_record()
         else:
             raise ValueError("Unrecognized ExpertDistributionReq value")
         return ExpertDistributionReqOutput()
python/sglang/srt/model_executor/model_runner.py  (+47, -0)

@@ -52,6 +52,16 @@ from sglang.srt.layers.quantization.deep_gemm import (
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+    set_global_expert_distribution_recorder,
+)
+from sglang.srt.managers.expert_location import (
+    compute_initial_expert_location_metadata,
+    get_global_expert_location_metadata,
+    set_global_expert_location_metadata,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     DoubleSparseTokenToKVPool,

@@ -161,6 +171,8 @@ class ModelRunner:
         self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
         self.attention_chunk_size = model_config.attention_chunk_size

+        self.forward_pass_id = 0
+
         # Model-specific adjustment
         self.model_specific_adjustment()

@@ -219,6 +231,25 @@ class ModelRunner:
             enable=self.server_args.enable_memory_saver
         )

+        if not self.is_draft_worker:
+            set_global_expert_location_metadata(
+                compute_initial_expert_location_metadata(server_args, self.model_config)
+            )
+            if self.tp_rank == 0 and get_bool_env_var(
+                "SGLANG_LOG_EXPERT_LOCATION_METADATA"
+            ):
+                logger.info(
+                    f"Initial expert_location_metadata: {get_global_expert_location_metadata().debug_str()}"
+                )
+
+            set_global_expert_distribution_recorder(
+                ExpertDistributionRecorder.init_new(
+                    server_args,
+                    get_global_expert_location_metadata(),
+                    rank=self.tp_rank,
+                )
+            )
+
         # Load the model
         self.sampler = Sampler()
         self.load_model()

@@ -1093,6 +1124,22 @@ class ModelRunner:
         forward_batch: ForwardBatch,
         skip_attn_backend_init: bool = False,
         pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
+        self.forward_pass_id += 1
+        with get_global_expert_distribution_recorder().with_forward_pass(
+            self.forward_pass_id,
+            forward_batch,
+        ):
+            return self._forward_raw(
+                forward_batch, skip_attn_backend_init, pp_proxy_tensors
+            )
+
+    def _forward_raw(
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool,
+        pp_proxy_tensors: Optional[PPProxyTensors],
+    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
         can_run_cuda_graph = bool(
             forward_batch.forward_mode.is_cuda_graph()
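The pattern above, incrementing a pass id and wrapping the real forward in a recorder context, can be illustrated in isolation. This is a toy sketch, not the actual recorder; everything below is an assumption except the wrap-and-delegate shape.

from contextlib import contextmanager


class _TinyRecorderSketch:
    def __init__(self):
        self.current_pass = None

    @contextmanager
    def with_forward_pass(self, forward_pass_id, forward_batch):
        self.current_pass = forward_pass_id
        try:
            yield  # anything recorded inside is attributed to this pass
        finally:
            self.current_pass = None


recorder = _TinyRecorderSketch()
forward_pass_id = 0
for _ in range(3):
    forward_pass_id += 1
    with recorder.with_forward_pass(forward_pass_id, forward_batch=None):
        pass  # the model forward would run here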
python/sglang/srt/models/deepseek_v2.py  (+20, -9)

@@ -77,7 +77,11 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+)
+from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader

@@ -109,8 +113,6 @@ if _is_hip:
         decode_attention_fwd_grouped_rope,
     )

-expert_distribution_recorder = ExpertDistributionRecorder()
-
 logger = logging.getLogger(__name__)

@@ -302,6 +304,7 @@ class DeepseekV2MoE(nn.Module):
     def forward(
         self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
     ) -> torch.Tensor:
         forward_mode = forward_batch.forward_mode
         if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
             forward_mode, hidden_states
         ):

@@ -1278,7 +1281,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         )

         # Fully Connected
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, forward_batch)

         # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
         # Scatter

@@ -1422,11 +1425,11 @@ class DeepseekV2Model(nn.Module):
         residual = None
         for i in range(len(self.layers)):
-            expert_distribution_recorder.set_current_layer(i)
-            layer = self.layers[i]
-            hidden_states, residual = layer(
-                positions, hidden_states, forward_batch, residual, zero_allocator
-            )
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                layer = self.layers[i]
+                hidden_states, residual = layer(
+                    positions, hidden_states, forward_batch, residual, zero_allocator
+                )
         if not forward_batch.forward_mode.is_idle():
             if residual is None:
                 hidden_states = self.norm(hidden_states)

@@ -1872,6 +1875,14 @@ class DeepseekV2ForCausalLM(nn.Module):
         torch.cuda.empty_cache()
         torch.cuda.synchronize()

+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        return ModelConfigForExpertLocation(
+            num_layers=config.num_hidden_layers,
+            num_logical_experts=config.n_routed_experts,
+            num_groups=config.n_group,
+        )
+

 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
     pass
python/sglang/srt/models/qwen2_moe.py  (+18, -8)

@@ -59,14 +59,16 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+)
+from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, make_layers

-expert_distribution_recorder = ExpertDistributionRecorder()
-
 logger = logging.getLogger(__name__)

@@ -591,11 +593,11 @@ class Qwen2MoeModel(nn.Module):
             residual = pp_proxy_tensors["residual"]
         for i in range(self.start_layer, self.end_layer):
-            expert_distribution_recorder.set_current_layer(i)
-            layer = self.layers[i]
-            hidden_states, residual = layer(positions, hidden_states, forward_batch, residual)
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                layer = self.layers[i]
+                hidden_states, residual = layer(
+                    positions, hidden_states, forward_batch, residual
+                )
         if not self.pp_group.is_last_rank:
             return PPProxyTensors(
                 {

@@ -752,5 +754,13 @@ class Qwen2MoeForCausalLM(nn.Module):
             else:
                 logger.warning(f"Parameter {name} not found in params_dict")

+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        return ModelConfigForExpertLocation(
+            num_layers=config.num_hidden_layers,
+            num_logical_experts=config.num_experts,
+            num_groups=None,
+        )
+

 EntryClass = Qwen2MoeForCausalLM
python/sglang/srt/server_args.py  (+32, -0)

@@ -170,6 +170,11 @@ class ServerArgs:
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
+    init_expert_location: str = "trivial"
+    expert_distribution_recorder_mode: Optional[
+        Literal["stat", "per_pass", "per_token"]
+    ] = None
+    expert_distribution_recorder_buffer_size: Optional[int] = None
     deepep_config: Optional[str] = None
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32

@@ -361,6 +366,15 @@ class ServerArgs:
             "Pipeline parallelism is incompatible with overlap schedule."
         )

+        if self.expert_distribution_recorder_buffer_size is None:
+            # TODO pr-chain: enable this later
+            # if (x := self.eplb_rebalance_num_iterations) is not None:
+            #     self.expert_distribution_recorder_buffer_size = x
+            if False:
+                pass
+            elif self.expert_distribution_recorder_mode is not None:
+                self.expert_distribution_recorder_buffer_size = 1000
+
         # Speculative Decoding
         if self.speculative_algorithm == "NEXTN":
             # NEXTN shares the same implementation of EAGLE

@@ -1257,6 +1271,24 @@ class ServerArgs:
             default="auto",
             help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
         )
+        parser.add_argument(
+            "--init-expert-location",
+            type=str,
+            default=ServerArgs.init_expert_location,
+            help="Initial location of EP experts.",
+        )
+        parser.add_argument(
+            "--expert-distribution-recorder-mode",
+            type=str,
+            default=ServerArgs.expert_distribution_recorder_mode,
+            help="Mode of expert distribution recorder.",
+        )
+        parser.add_argument(
+            "--expert-distribution-recorder-buffer-size",
+            type=int,
+            default=ServerArgs.expert_distribution_recorder_buffer_size,
+            help="Circular buffer size of expert distribution recorder. Set to -1 to denote infinite buffer.",
+        )
         parser.add_argument(
             "--deepep-config",
             type=str,
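Per compute_initial_expert_location_metadata in the new expert_location.py, the --init-expert-location value may be "trivial", a path ending in .pt or .json, or an inline JSON string. A sketch of the inline JSON form follows; the concrete shape values are illustrative only, but the "physical_to_logical_map" key and the (num_layers, num_physical_experts) layout come from the new module.

import json

# 2 layers x 4 physical experts mapped onto 4 logical experts (identity here).
init_expert_location = json.dumps(
    {"physical_to_logical_map": [[0, 1, 2, 3], [0, 1, 2, 3]]}
)
print(init_expert_location)
# Usable as, e.g.:
#   --init-expert-location '{"physical_to_logical_map": [[0, 1, 2, 3], [0, 1, 2, 3]]}'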
python/sglang/srt/utils.py  (+35, -1)

@@ -46,7 +46,19 @@ from importlib.util import find_spec
 from io import BytesIO
 from multiprocessing.reduction import ForkingPickler
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generic,
+    List,
+    Optional,
+    Protocol,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+)

 import numpy as np
 import psutil

@@ -2126,3 +2138,25 @@ def load_json_config(data: str):
 def dispose_tensor(x: torch.Tensor):
     x.set_(torch.empty((0,), device=x.device, dtype=x.dtype))

+T = TypeVar("T")
+
+
+class Withable(Generic[T]):
+    def __init__(self):
+        self._value: Optional[T] = None
+
+    @property
+    def value(self) -> T:
+        return self._value
+
+    @contextmanager
+    def with_value(self, new_value: T):
+        assert self._value is None
+        self._value = new_value
+        try:
+            yield
+        finally:
+            assert self._value is new_value
+            self._value = None
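Withable is a small context-managed holder for a single value: it asserts the slot is empty on entry, exposes the value while the block runs, and asserts it was not swapped out before clearing it on exit. A usage sketch:

from sglang.srt.utils import Withable

current_layer = Withable()  # holds an Optional[int]

with current_layer.with_value(3):
    assert current_layer.value == 3  # readable anywhere inside the block
assert current_layer.value is None  # cleared (and checked unchanged) on exit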
test/srt/test_expert_distribution.py  (+71, -92)

-import csv
 import glob
 import os
+import tempfile
 import unittest
+from pathlib import Path

 import requests
+import torch

 from sglang.srt.utils import kill_process_tree
 from sglang.test.test_utils import (

@@ -16,108 +17,86 @@ from sglang.test.test_utils import (
 class TestExpertDistribution(CustomTestCase):
-    def setUp(self):
-        # Clean up any existing expert distribution files before each test
-        for f in glob.glob("expert_distribution_*.csv"):
-            os.remove(f)
-
-    def tearDown(self):
-        # Clean up any expert distribution files after each test
-        for f in glob.glob("expert_distribution_*.csv"):
-            os.remove(f)
-
     def test_expert_distribution_record(self):
-        """Test expert distribution record endpoints"""
-        process = popen_launch_server(
-            # The feature is only implemented in deepseek_v2.py
-            "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
-            DEFAULT_URL_FOR_TEST,
-            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=[
-                "--trust-remote-code",
-            ],
-        )
+        # TODO: Add tests for DeepEP gatherer (currently our CI cannot run that)
+        for info in [
+            dict(model_path="deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"),
+            dict(model_path="Qwen/Qwen1.5-MoE-A2.7B"),
+            dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", tp_size=2),
+            # TODO enable in next PR
+            # dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode="per_pass"),
+            # dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode="per_token"),
+        ]:
+            with self.subTest(info=info):
+                self._execute_core(**info)

-        try:
-            # Start recording
-            response = requests.post(
-                f"{DEFAULT_URL_FOR_TEST}/start_expert_distribution_record"
-            )
-            self.assertEqual(response.status_code, 200)
+    def _execute_core(self, model_path: str, mode: str = "stat", tp_size: int = 1):
+        """Test expert distribution record endpoints"""
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            os.environ["SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR"] = tmp_dir
+            process = popen_launch_server(
+                model_path,
+                DEFAULT_URL_FOR_TEST,
+                timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+                other_args=[
+                    "--trust-remote-code",
+                    "--tp-size",
+                    str(tp_size),
+                    "--expert-distribution-recorder-mode",
+                    mode,
+                    "--disable-cuda-graph",
+                    "--disable-overlap-schedule",
+                ],
+            )

-            # Make some requests to generate expert distribution data
-            response = requests.post(
-                f"{DEFAULT_URL_FOR_TEST}/generate",
-                json={
-                    "text": "The capital of France is",
-                    "sampling_params": {
-                        "temperature": 0,
-                        "max_new_tokens": 32,
-                    },
-                },
-            )
-            self.assertEqual(response.status_code, 200)
+            try:
+                # Start recording
+                response = requests.post(
+                    f"{DEFAULT_URL_FOR_TEST}/start_expert_distribution_record"
+                )
+                self.assertEqual(response.status_code, 200)

-            # Stop recording
-            response = requests.post(
-                f"{DEFAULT_URL_FOR_TEST}/stop_expert_distribution_record"
-            )
-            self.assertEqual(response.status_code, 200)
+                # Make some requests to generate expert distribution data
+                response = requests.post(
+                    f"{DEFAULT_URL_FOR_TEST}/generate",
+                    json={
+                        "text": "The capital of France is",
+                        "sampling_params": {
+                            "temperature": 0,
+                            "max_new_tokens": 32,
+                        },
+                    },
+                )
+                self.assertEqual(response.status_code, 200)

-            # Dump the recorded data
-            response = requests.post(
-                f"{DEFAULT_URL_FOR_TEST}/dump_expert_distribution_record"
-            )
-            self.assertEqual(response.status_code, 200)
+                # Stop recording
+                response = requests.post(
+                    f"{DEFAULT_URL_FOR_TEST}/stop_expert_distribution_record"
+                )
+                self.assertEqual(response.status_code, 200)

-            # Verify the dumped file exists and has correct format
-            csv_files = glob.glob("expert_distribution_*.csv")
-            self.assertEqual(
-                len(csv_files),
-                1,
-                f"Expected exactly one expert distribution CSV file {csv_files=}",
-            )
+                # Dump the recorded data
+                response = requests.post(
+                    f"{DEFAULT_URL_FOR_TEST}/dump_expert_distribution_record"
+                )
+                self.assertEqual(response.status_code, 200)

-            # Check CSV file format
-            with open(csv_files[0], "r") as f:
-                csv_reader = csv.reader(f)
-                # Check header
-                header = next(csv_reader)
-                self.assertEqual(
-                    header,
-                    ["layer_id", "expert_id", "count"],
-                    "CSV header should be 'layer_id,expert_id,count'",
-                )
-                # Check data rows
-                rows = list(csv_reader)
-                self.assertGreater(len(rows), 0, "CSV file should contain data rows")
-                for row in rows:
-                    # Verify each row has 3 columns
-                    self.assertEqual(
-                        len(row),
-                        3,
-                        "Each row should have layer_id, expert_id and count",
-                    )
-                    # Verify data types
-                    layer_id, expert_id, count = row
-                    self.assertTrue(
-                        layer_id.isdigit(),
-                        f"layer_id should be an integer {row=} {rows=}",
-                    )
-                    self.assertTrue(
-                        expert_id.isdigit(),
-                        f"expert_id should be an integer {row=} {rows=}",
-                    )
-                    self.assertTrue(
-                        count.isdigit(), f"count should be an integer {row=} {rows=}"
-                    )
-        finally:
-            kill_process_tree(process.pid)
+                data = torch.load(
+                    list(Path(tmp_dir).glob("*.pt"))[0], weights_only=True
+                )
+                print(f"{data=}")

+                if mode in ["per_pass", "per_token"]:
+                    self.assertGreater(len(data), 0, "Should contain data rows")
+                else:
+                    logical_count = data["logical_count"]
+                    print(f"{logical_count.sum()=} {logical_count=}")
+                    self.assertTrue(logical_count.sum() > 0)
+            finally:
+                kill_process_tree(process.pid)


 if __name__ == "__main__":
...
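Outside the test harness, a dumped record can be inspected the same way. A minimal sketch, assuming the server was launched with SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR pointing at a known directory (the path below is hypothetical) and that, as the assertions above indicate, "stat" mode dumps carry a "logical_count" tensor:

from pathlib import Path

import torch

# Hypothetical dump directory set via SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR
# before launching the server.
dump_dir = Path("/tmp/expert_distribution_dumps")
data = torch.load(next(dump_dir.glob("*.pt")), weights_only=True)

# Total routed tokens across all layers and logical experts.
print(data["logical_count"].sum())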