Commit f0653886 (unverified)
Authored May 20, 2025 by fzyzcjy; committed by GitHub May 19, 2025
Expert distribution recording without overhead for EPLB (#4957)
Parent: b1465557

Showing 12 changed files with 1119 additions and 190 deletions:
- docs/backend/native_api.ipynb (+2, -14)
- python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py (+14, -0)
- python/sglang/srt/layers/moe/topk.py (+5, -4)
- python/sglang/srt/managers/expert_distribution.py (+595, -56)
- python/sglang/srt/managers/expert_location.py (+273, -0)
- python/sglang/srt/managers/scheduler.py (+7, -6)
- python/sglang/srt/model_executor/model_runner.py (+47, -0)
- python/sglang/srt/models/deepseek_v2.py (+20, -9)
- python/sglang/srt/models/qwen2_moe.py (+18, -8)
- python/sglang/srt/server_args.py (+32, -0)
- python/sglang/srt/utils.py (+35, -1)
- test/srt/test_expert_distribution.py (+71, -92)
docs/backend/native_api.ipynb

@@ -390,7 +390,7 @@
   "outputs": [],
   "source": [
    "expert_record_server_process, port = launch_server_cmd(\n",
-   "    \"python3 -m sglang.launch_server --model-path Qwen/Qwen1.5-MoE-A2.7B --host 0.0.0.0\"\n",
+   "    \"python3 -m sglang.launch_server --model-path Qwen/Qwen1.5-MoE-A2.7B --host 0.0.0.0 --expert-distribution-recorder-mode stat\"\n",
   ")\n",
   "\n",
   "wait_for_server(f\"http://localhost:{port}\")"

@@ -415,19 +415,7 @@
   "print_highlight(response)\n",
   "\n",
   "response = requests.post(f\"http://localhost:{port}/dump_expert_distribution_record\")\n",
-  "print_highlight(response)\n",
-  "\n",
-  "import glob\n",
-  "\n",
-  "output_file = glob.glob(\"expert_distribution_*.csv\")[0]\n",
-  "with open(output_file, \"r\") as f:\n",
-  "    print_highlight(\"\\n| Layer ID | Expert ID | Count |\")\n",
-  "    print_highlight(\"|----------|-----------|--------|\")\n",
-  "    next(f)\n",
-  "    for i, line in enumerate(f):\n",
-  "        if i < 9:\n",
-  "            layer_id, expert_id, count = line.strip().split(\",\")\n",
-  "            print_highlight(f\"| {layer_id:8} | {expert_id:9} | {count:6} |\")"
+  "print_highlight(response)"
  ]
 },
 {
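As a usage sketch (not part of the commit itself): once the server is up with --expert-distribution-recorder-mode stat, recording is driven entirely over HTTP, as the updated notebook and the test below do. The port is illustrative.

import requests

base = "http://localhost:30000"  # illustrative port

requests.post(f"{base}/start_expert_distribution_record")
requests.post(
    f"{base}/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 32},
    },
)
requests.post(f"{base}/stop_expert_distribution_record")
# Writes a .pt dump (into SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR, if set)
requests.post(f"{base}/dump_expert_distribution_record")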
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py

import logging

from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
+from sglang.srt.managers.expert_distribution import (
+    get_global_expert_distribution_recorder,
+)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import DeepEPMode, load_json_config

@@ -326,6 +329,13 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
            config=_DeepEPConfig.get_instance().normal_dispatch_config,
        )
+        get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
+            num_recv_tokens_per_expert_list,
+            num_tokens_per_rank=num_tokens_per_rank,
+            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
+            num_tokens_per_expert=num_tokens_per_expert,
+        )
        return (
            recv_x,
            recv_topk_idx,

@@ -489,6 +499,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
        ):
            hook() if self.return_recv_hook else event.current_stream_wait()
+        get_global_expert_distribution_recorder().on_deepep_dispatch_low_latency(
+            masked_m
+        )
        reorder_topk_ids = seg_indptr = None
        return (
python/sglang/srt/layers/moe/topk.py

@@ -18,7 +18,10 @@ from typing import Callable, Optional

import torch
import torch.nn.functional as F

-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip

@@ -31,8 +34,6 @@ if _is_cuda:
if _is_cuda or _is_hip:
    from sgl_kernel import topk_softmax

-expert_distribution_recorder = ExpertDistributionRecorder()
-

def fused_topk_native(
    hidden_states: torch.Tensor,

@@ -353,6 +354,6 @@ def select_experts(
        renormalize=renormalize,
    )

-    expert_distribution_recorder.record_new_token(topk_ids)
+    get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)

    return topk_weights, topk_ids
python/sglang/srt/managers/expert_distribution.py (+595, -56)

(This diff is collapsed in the source view and is not shown here.)
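Since the reworked recorder itself is collapsed above, here is a skeleton inferred purely from the call sites visible elsewhere in this diff; it is NOT the actual implementation, and the no-op bodies are illustrative:

from contextlib import contextmanager


class ExpertDistributionRecorder:
    # Constructed once per rank in ModelRunner (see model_runner.py below).
    @staticmethod
    def init_new(server_args, expert_location_metadata, rank):
        return ExpertDistributionRecorder()

    # Control plane, driven by Scheduler.expert_distribution_handle.
    def start_record(self): ...
    def stop_record(self): ...
    def dump_record(self): ...

    # Scoping contexts used by ModelRunner.forward and the MoE models.
    @contextmanager
    def with_forward_pass(self, forward_pass_id, forward_batch):
        yield

    @contextmanager
    def with_current_layer(self, layer_idx):
        yield

    # Data-plane hooks; presumably kept as near-free no-ops while recording
    # is disabled, which is how the commit records "without overhead".
    def on_select_experts(self, topk_ids): ...
    def on_deepep_dispatch_normal(self, num_recv_tokens_per_expert_list, **kwargs): ...
    def on_deepep_dispatch_low_latency(self, masked_m): ...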
python/sglang/srt/managers/expert_location.py (new file, mode 100644)

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import torch
import torch.distributed
import torch.nn.functional as F

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.model_loader import get_model_architecture
from sglang.srt.server_args import ServerArgs

logger = logging.getLogger(__name__)


@dataclass
class ExpertLocationMetadata:
    physical_to_logical_map: torch.Tensor  # (layers, num_physical_experts)
    logical_to_all_physical_map: torch.Tensor  # (layers, num_logical_experts, X)
    logical_to_all_physical_map_num_valid: torch.Tensor  # (layers, num_logical_experts)

    # -------------------------------- properties ------------------------------------

    @property
    def num_layers(self) -> int:
        return self.physical_to_logical_map.shape[0]

    @property
    def num_physical_experts(self) -> int:
        return self.physical_to_logical_map.shape[1]

    @property
    def num_local_physical_experts(self) -> int:
        ans, remainder = divmod(self.num_physical_experts, self.ep_size)
        assert remainder == 0
        return ans

    @property
    def num_logical_experts(self) -> int:
        return self.logical_to_all_physical_map.shape[1]

    @property
    def ep_size(self):
        # TODO change when EP size != world size
        return torch.distributed.get_world_size()

    def __post_init__(self):
        num_layers_0, num_physical_experts_0 = self.physical_to_logical_map.shape
        num_layers_1, num_logical_experts_0, num_physical_experts_1 = (
            self.logical_to_all_physical_map.shape
        )
        num_layers_2, num_logical_experts_1 = (
            self.logical_to_all_physical_map_num_valid.shape
        )
        # TODO pr-chain: enable this later
        # assert num_layers_0 == num_layers_1 == num_layers_2 == num_layers_3
        # assert num_logical_experts_0 == num_logical_experts_1 == num_logical_experts_2
        assert num_physical_experts_0 == num_physical_experts_1

    # -------------------------------- construction ------------------------------------

    @staticmethod
    def init_trivial(server_args: ServerArgs, model_config: ModelConfig):
        """Trivial location - logical expert i corresponds to physical expert i"""
        common = ExpertLocationMetadata._init_common(server_args, model_config)
        num_physical_experts = common["num_physical_experts"]
        model_config_for_expert_location = common["model_config_for_expert_location"]
        num_layers = model_config_for_expert_location.num_layers
        num_logical_experts = model_config_for_expert_location.num_logical_experts

        physical_to_logical_map = (
            torch.arange(0, num_physical_experts).repeat(num_layers, 1)
            % num_logical_experts
        )

        return ExpertLocationMetadata.init_by_mapping(
            server_args,
            model_config,
            physical_to_logical_map=physical_to_logical_map,
        )

    @staticmethod
    def init_by_mapping(
        server_args: ServerArgs,
        model_config: ModelConfig,
        physical_to_logical_map,
    ):
        if not isinstance(physical_to_logical_map, torch.Tensor):
            physical_to_logical_map = torch.tensor(physical_to_logical_map)
        physical_to_logical_map = physical_to_logical_map.to(server_args.device)

        common = ExpertLocationMetadata._init_common(server_args, model_config)
        model_config_for_expert_location = common["model_config_for_expert_location"]
        logical_to_all_physical_map = _compute_logical_to_all_physical_map(
            physical_to_logical_map,
            num_logical_experts=model_config_for_expert_location.num_logical_experts,
        )

        return ExpertLocationMetadata._init_raw(
            ep_size=common["ep_size"],
            physical_to_logical_map=physical_to_logical_map,
            logical_to_all_physical_map=logical_to_all_physical_map,
        )

    @staticmethod
    def _init_common(server_args: ServerArgs, model_config: ModelConfig):
        model_config_for_expert_location = (
            ModelConfigForExpertLocation.from_model_config(model_config)
        )

        num_physical_experts = (
            model_config_for_expert_location.num_logical_experts
            # TODO pr-chain: enable this later
            # + server_args.ep_num_redundant_experts
        )
        ep_size = server_args.ep_size
        assert num_physical_experts % ep_size == 0
        num_local_physical_experts = num_physical_experts // ep_size

        return dict(
            model_config_for_expert_location=model_config_for_expert_location,
            num_physical_experts=num_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
            ep_size=ep_size,
        )

    @staticmethod
    def _init_raw(
        ep_size: int,
        physical_to_logical_map: torch.Tensor,
        logical_to_all_physical_map: torch.Tensor,
    ):
        _, num_physical_experts = physical_to_logical_map.shape

        logical_to_all_physical_map_padded = F.pad(
            logical_to_all_physical_map,
            (0, num_physical_experts - logical_to_all_physical_map.shape[-1]),
            value=-1,
        )

        logical_to_all_physical_map_num_valid = torch.count_nonzero(
            logical_to_all_physical_map != -1, dim=-1
        )

        return ExpertLocationMetadata(
            physical_to_logical_map=physical_to_logical_map,
            logical_to_all_physical_map=logical_to_all_physical_map_padded,
            logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
        )


_global_expert_location_metadata: Optional[ExpertLocationMetadata] = None


def get_global_expert_location_metadata():
    return _global_expert_location_metadata


def set_global_expert_location_metadata(value):
    global _global_expert_location_metadata
    assert _global_expert_location_metadata is None
    _global_expert_location_metadata = value


def _compute_logical_to_all_physical_map(
    physical_to_logical_map: torch.Tensor, num_logical_experts: int
):
    # This is rarely called, so we use for loops for maximum clarity
    num_layers, num_physical_experts = physical_to_logical_map.shape

    logical_to_all_physical_map = [
        [[] for _ in range(num_logical_experts)] for _ in range(num_layers)
    ]
    for layer_id in range(num_layers):
        for physical_expert_id in range(num_physical_experts):
            logical_expert_id = physical_to_logical_map[
                layer_id, physical_expert_id
            ].item()
            logical_to_all_physical_map[layer_id][logical_expert_id].append(
                physical_expert_id
            )

    logical_to_all_physical_map = _pad_nested_array(
        logical_to_all_physical_map, pad_value=-1
    )

    return torch.tensor(
        logical_to_all_physical_map, device=physical_to_logical_map.device
    )


def _pad_nested_array(arr, pad_value):
    max_len = max(len(inner) for outer in arr for inner in outer)
    padded = [
        [inner + [pad_value] * (max_len - len(inner)) for inner in outer]
        for outer in arr
    ]
    return padded


@dataclass
class ModelConfigForExpertLocation:
    num_layers: int
    num_logical_experts: int
    num_groups: Optional[int] = None

    @staticmethod
    def init_dummy():
        return ModelConfigForExpertLocation(num_layers=1, num_logical_experts=1)

    @staticmethod
    def from_model_config(model_config: ModelConfig):
        model_class, _ = get_model_architecture(model_config)
        if hasattr(model_class, "get_model_config_for_expert_location"):
            return model_class.get_model_config_for_expert_location(
                model_config.hf_config
            )
        else:
            return ModelConfigForExpertLocation.init_dummy()


def compute_initial_expert_location_metadata(
    server_args: ServerArgs, model_config: ModelConfig
) -> ExpertLocationMetadata:
    data = server_args.init_expert_location
    if data == "trivial":
        logger.info("init_expert_location from trivial")
        return ExpertLocationMetadata.init_trivial(server_args, model_config)

    # TODO unify with the utils function
    if data.endswith(".pt"):
        data_dict = torch.load(data, weights_only=True)
    elif data.endswith(".json"):
        data_dict = json.loads(Path(data).read_text())
    else:
        data_dict = json.loads(data)

    if "physical_to_logical_map" in data_dict:
        logger.info(
            "init_expert_location from init_by_mapping using ServerArgs.init_expert_location"
        )
        return ExpertLocationMetadata.init_by_mapping(
            server_args, model_config, **data_dict
        )
    elif "logical_count" in data_dict:
        # TODO pr-chain: enable this later
        raise NotImplementedError
        # logger.info(
        #     "init_expert_location from init_by_eplb using ServerArgs.init_expert_location"
        # )
        # return ExpertLocationMetadata.init_by_eplb(
        #     server_args, model_config, logical_count=data_dict["logical_count"]
        # )
    else:
        raise NotImplementedError(
            f"Unknown init_expert_location format ({list(data_dict.keys())=})"
        )
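To make the two mappings concrete, here is a small worked example with hypothetical sizes (2 layers, 4 physical experts, 3 logical experts), following init_trivial and _compute_logical_to_all_physical_map above:

import torch

num_layers, num_physical_experts, num_logical_experts = 2, 4, 3

# init_trivial: physical expert i serves logical expert i % num_logical_experts.
physical_to_logical_map = (
    torch.arange(0, num_physical_experts).repeat(num_layers, 1) % num_logical_experts
)
print(physical_to_logical_map)
# tensor([[0, 1, 2, 0],
#         [0, 1, 2, 0]])

# The inverse map lists every physical replica of each logical expert, padded
# with -1; per layer it is [[0, 3], [1, -1], [2, -1]], since logical expert 0
# is replicated on physical experts 0 and 3. _init_raw then right-pads each row
# to num_physical_experts columns, and
# logical_to_all_physical_map_num_valid comes out as [2, 1, 1] per layer.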
python/sglang/srt/managers/scheduler.py

@@ -59,7 +59,10 @@ from sglang.srt.hf_transformers_utils import (
)
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+)
from sglang.srt.managers.io_struct import (
    AbortReq,
    CloseSessionReqInput,

@@ -142,8 +145,6 @@ from sglang.srt.utils import (
)
from sglang.utils import TypeBasedDispatcher, get_exception_traceback

-expert_distribution_recorder = ExpertDistributionRecorder()
-
logger = logging.getLogger(__name__)

# Test retract decode for debugging purposes

@@ -2162,11 +2163,11 @@ class Scheduler(
    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
        if recv_req == ExpertDistributionReq.START_RECORD:
-            expert_distribution_recorder.start_record()
+            get_global_expert_distribution_recorder().start_record()
        elif recv_req == ExpertDistributionReq.STOP_RECORD:
-            expert_distribution_recorder.stop_record()
+            get_global_expert_distribution_recorder().stop_record()
        elif recv_req == ExpertDistributionReq.DUMP_RECORD:
-            expert_distribution_recorder.dump_record()
+            get_global_expert_distribution_recorder().dump_record()
        else:
            raise ValueError("Unrecognized ExpertDistributionReq value")
        return ExpertDistributionReqOutput()
python/sglang/srt/model_executor/model_runner.py

@@ -52,6 +52,16 @@ from sglang.srt.layers.quantization.deep_gemm import (
from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.lora.lora_manager import LoRAManager
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+    set_global_expert_distribution_recorder,
+)
+from sglang.srt.managers.expert_location import (
+    compute_initial_expert_location_metadata,
+    get_global_expert_location_metadata,
+    set_global_expert_location_metadata,
+)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
    DoubleSparseTokenToKVPool,

@@ -161,6 +171,8 @@ class ModelRunner:
        self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
        self.attention_chunk_size = model_config.attention_chunk_size
+        self.forward_pass_id = 0
+
        # Model-specific adjustment
        self.model_specific_adjustment()

@@ -219,6 +231,25 @@ class ModelRunner:
            enable=self.server_args.enable_memory_saver
        )
+        if not self.is_draft_worker:
+            set_global_expert_location_metadata(
+                compute_initial_expert_location_metadata(server_args, self.model_config)
+            )
+            if self.tp_rank == 0 and get_bool_env_var(
+                "SGLANG_LOG_EXPERT_LOCATION_METADATA"
+            ):
+                logger.info(
+                    f"Initial expert_location_metadata: {get_global_expert_location_metadata().debug_str()}"
+                )
+
+            set_global_expert_distribution_recorder(
+                ExpertDistributionRecorder.init_new(
+                    server_args,
+                    get_global_expert_location_metadata(),
+                    rank=self.tp_rank,
+                )
+            )
+
        # Load the model
        self.sampler = Sampler()
        self.load_model()

@@ -1093,6 +1124,22 @@ class ModelRunner:
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
+        self.forward_pass_id += 1
+
+        with get_global_expert_distribution_recorder().with_forward_pass(
+            self.forward_pass_id,
+            forward_batch,
+        ):
+            return self._forward_raw(
+                forward_batch, skip_attn_backend_init, pp_proxy_tensors
+            )
+
+    def _forward_raw(
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool,
+        pp_proxy_tensors: Optional[PPProxyTensors],
+    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
        can_run_cuda_graph = bool(forward_batch.forward_mode.is_cuda_graph()
python/sglang/srt/models/deepseek_v2.py

@@ -77,7 +77,11 @@ from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+)
+from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_loader.weight_utils import default_weight_loader

@@ -109,8 +113,6 @@ if _is_hip:
    decode_attention_fwd_grouped_rope,
)

-expert_distribution_recorder = ExpertDistributionRecorder()
-
logger = logging.getLogger(__name__)

@@ -302,6 +304,7 @@ class DeepseekV2MoE(nn.Module):
    def forward(
        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
    ) -> torch.Tensor:
        forward_mode = forward_batch.forward_mode
        if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
            forward_mode, hidden_states
        ):

@@ -1278,7 +1281,7 @@ class DeepseekV2DecoderLayer(nn.Module):
        )

        # Fully Connected
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, forward_batch)

        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
        # Scatter

@@ -1422,11 +1425,11 @@ class DeepseekV2Model(nn.Module):
        residual = None
        for i in range(len(self.layers)):
-            expert_distribution_recorder.set_current_layer(i)
-            layer = self.layers[i]
-            hidden_states, residual = layer(
-                positions, hidden_states, forward_batch, residual, zero_allocator
-            )
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                layer = self.layers[i]
+                hidden_states, residual = layer(
+                    positions, hidden_states, forward_batch, residual, zero_allocator
+                )
        if not forward_batch.forward_mode.is_idle():
            if residual is None:
                hidden_states = self.norm(hidden_states)

@@ -1872,6 +1875,14 @@ class DeepseekV2ForCausalLM(nn.Module):
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        return ModelConfigForExpertLocation(
+            num_layers=config.num_hidden_layers,
+            num_logical_experts=config.n_routed_experts,
+            num_groups=config.n_group,
+        )


class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    pass
python/sglang/srt/models/qwen2_moe.py

@@ -59,14 +59,16 @@ from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+)
+from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, make_layers

-expert_distribution_recorder = ExpertDistributionRecorder()
-
logger = logging.getLogger(__name__)

@@ -591,11 +593,11 @@ class Qwen2MoeModel(nn.Module):
            residual = pp_proxy_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
-            expert_distribution_recorder.set_current_layer(i)
-            layer = self.layers[i]
-            hidden_states, residual = layer(
-                positions, hidden_states, forward_batch, residual
-            )
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                layer = self.layers[i]
+                hidden_states, residual = layer(
+                    positions, hidden_states, forward_batch, residual
+                )
        if not self.pp_group.is_last_rank:
            return PPProxyTensors(
                {

@@ -752,5 +754,13 @@ class Qwen2MoeForCausalLM(nn.Module):
            else:
                logger.warning(f"Parameter {name} not found in params_dict")

+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        return ModelConfigForExpertLocation(
+            num_layers=config.num_hidden_layers,
+            num_logical_experts=config.num_experts,
+            num_groups=None,
+        )


EntryClass = Qwen2MoeForCausalLM
python/sglang/srt/server_args.py

@@ -170,6 +170,11 @@ class ServerArgs:
    enable_ep_moe: bool = False
    enable_deepep_moe: bool = False
    deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
+    init_expert_location: str = "trivial"
+    expert_distribution_recorder_mode: Optional[
+        Literal["stat", "per_pass", "per_token"]
+    ] = None
+    expert_distribution_recorder_buffer_size: Optional[int] = None
    deepep_config: Optional[str] = None
    enable_torch_compile: bool = False
    torch_compile_max_bs: int = 32

@@ -361,6 +366,15 @@ class ServerArgs:
                "Pipeline parallelism is incompatible with overlap schedule."
            )

+        if self.expert_distribution_recorder_buffer_size is None:
+            # TODO pr-chain: enable this later
+            # if (x := self.eplb_rebalance_num_iterations) is not None:
+            #     self.expert_distribution_recorder_buffer_size = x
+            if False:
+                pass
+            elif self.expert_distribution_recorder_mode is not None:
+                self.expert_distribution_recorder_buffer_size = 1000
+
        # Speculative Decoding
        if self.speculative_algorithm == "NEXTN":
            # NEXTN shares the same implementation of EAGLE

@@ -1257,6 +1271,24 @@ class ServerArgs:
            default="auto",
            help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
        )
+        parser.add_argument(
+            "--init-expert-location",
+            type=str,
+            default=ServerArgs.init_expert_location,
+            help="Initial location of EP experts.",
+        )
+        parser.add_argument(
+            "--expert-distribution-recorder-mode",
+            type=str,
+            default=ServerArgs.expert_distribution_recorder_mode,
+            help="Mode of expert distribution recorder.",
+        )
+        parser.add_argument(
+            "--expert-distribution-recorder-buffer-size",
+            type=int,
+            default=ServerArgs.expert_distribution_recorder_buffer_size,
+            help="Circular buffer size of expert distribution recorder. Set to -1 to denote infinite buffer.",
+        )
        parser.add_argument(
            "--deepep-config",
            type=str,
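For reference, a sketch of how the new flags compose, based on compute_initial_expert_location_metadata in expert_location.py above (the mapping values are placeholders):

import json

# Default: the trivial identity-style layout.
args = ["--init-expert-location", "trivial"]

# Or an explicit mapping, passed inline as JSON (a .json or .pt file path is
# also accepted by the loader shown above).
mapping = {"physical_to_logical_map": [[0, 1, 2, 0], [0, 1, 2, 0]]}
args = ["--init-expert-location", json.dumps(mapping)]

# Enable recording; -1 requests an infinite buffer per the help text.
args += [
    "--expert-distribution-recorder-mode", "stat",
    "--expert-distribution-recorder-buffer-size", "-1",
]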
python/sglang/srt/utils.py

@@ -46,7 +46,19 @@ from importlib.util import find_spec
from io import BytesIO
from multiprocessing.reduction import ForkingPickler
from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generic,
+    List,
+    Optional,
+    Protocol,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+)

import numpy as np
import psutil

@@ -2126,3 +2138,25 @@ def load_json_config(data: str):
def dispose_tensor(x: torch.Tensor):
    x.set_(torch.empty((0,), device=x.device, dtype=x.dtype))
+
+
+T = TypeVar("T")
+
+
+class Withable(Generic[T]):
+    def __init__(self):
+        self._value: Optional[T] = None
+
+    @property
+    def value(self) -> T:
+        return self._value
+
+    @contextmanager
+    def with_value(self, new_value: T):
+        assert self._value is None
+        self._value = new_value
+        try:
+            yield
+        finally:
+            assert self._value is new_value
+            self._value = None
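A quick usage sketch of the new Withable helper: it scopes a value to a with block and asserts against nested or leaked values, which is how the recorder tracks the current layer and forward pass.

from sglang.srt.utils import Withable

current_layer: Withable[int] = Withable()

with current_layer.with_value(3):
    assert current_layer.value == 3  # readable anywhere inside the block
assert current_layer.value is None   # cleared (and checked) on exit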
test/srt/test_expert_distribution.py

(The test is reworked: the previous version launched only deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct and checked a dumped expert_distribution_*.csv file for a layer_id,expert_id,count header and integer rows; the new version parameterizes over models, TP sizes, and recorder modes, and verifies the dumped .pt file instead.)

import csv
import glob
import os
import tempfile
import unittest
from pathlib import Path

import requests
import torch

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
...

@@ -16,108 +17,86 @@ from sglang.test.test_utils import (
class TestExpertDistribution(CustomTestCase):
    def setUp(self):
        # Clean up any existing expert distribution files before each test
        for f in glob.glob("expert_distribution_*.csv"):
            os.remove(f)

    def tearDown(self):
        # Clean up any expert distribution files after each test
        for f in glob.glob("expert_distribution_*.csv"):
            os.remove(f)

    def test_expert_distribution_record(self):
        # TODO: Add tests for DeepEP gatherer (currently our CI cannot run that)
        for info in [
            dict(model_path="deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"),
            dict(model_path="Qwen/Qwen1.5-MoE-A2.7B"),
            dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", tp_size=2),
            # TODO enable in next PR
            # dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode="per_pass"),
            # dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode="per_token"),
        ]:
            with self.subTest(info=info):
                self._execute_core(**info)

    def _execute_core(self, model_path: str, mode: str = "stat", tp_size: int = 1):
        """Test expert distribution record endpoints"""
        with tempfile.TemporaryDirectory() as tmp_dir:
            os.environ["SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR"] = tmp_dir
            process = popen_launch_server(
                model_path,
                DEFAULT_URL_FOR_TEST,
                timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                other_args=[
                    "--trust-remote-code",
                    "--tp-size",
                    str(tp_size),
                    "--expert-distribution-recorder-mode",
                    mode,
                    "--disable-cuda-graph",
                    "--disable-overlap-schedule",
                ],
            )

            try:
                # Start recording
                response = requests.post(
                    f"{DEFAULT_URL_FOR_TEST}/start_expert_distribution_record"
                )
                self.assertEqual(response.status_code, 200)

                # Make some requests to generate expert distribution data
                response = requests.post(
                    f"{DEFAULT_URL_FOR_TEST}/generate",
                    json={
                        "text": "The capital of France is",
                        "sampling_params": {
                            "temperature": 0,
                            "max_new_tokens": 32,
                        },
                    },
                )
                self.assertEqual(response.status_code, 200)

                # Stop recording
                response = requests.post(
                    f"{DEFAULT_URL_FOR_TEST}/stop_expert_distribution_record"
                )
                self.assertEqual(response.status_code, 200)

                # Dump the recorded data
                response = requests.post(
                    f"{DEFAULT_URL_FOR_TEST}/dump_expert_distribution_record"
                )
                self.assertEqual(response.status_code, 200)

                data = torch.load(
                    list(Path(tmp_dir).glob("*.pt"))[0], weights_only=True
                )
                print(f"{data=}")

                if mode in ["per_pass", "per_token"]:
                    self.assertGreater(len(data), 0, "Should contain data rows")
                else:
                    logical_count = data["logical_count"]
                    print(f"{logical_count.sum()=} {logical_count=}")
                    self.assertTrue(logical_count.sum() > 0)
            finally:
                kill_process_tree(process.pid)


if __name__ == "__main__":
...
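For reference, a minimal sketch of inspecting a stat-mode dump outside the test, mirroring the assertions above (the directory is whatever SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR pointed to; the path here is illustrative):

from pathlib import Path

import torch

dump_dir = Path("/tmp/expert_dump")  # illustrative
data = torch.load(next(dump_dir.glob("*.pt")), weights_only=True)
logical_count = data["logical_count"]  # aggregated routing counts (see test above)
print(logical_count.sum())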