Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
b6d0ce9f
Unverified
Commit
b6d0ce9f
authored
Jun 03, 2025
by
fzyzcjy
Committed by
GitHub
Jun 02, 2025
Browse files
Minor add metrics to expert location updater (#6816)
parent
0ea330ca
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
68 additions
and
1 deletion
+68
-1
python/sglang/srt/model_executor/expert_location_updater.py
python/sglang/srt/model_executor/expert_location_updater.py
+68
-1
No files found.
python/sglang/srt/model_executor/expert_location_updater.py
View file @
b6d0ce9f
...
...
@@ -12,8 +12,10 @@
# limitations under the License.
# ==============================================================================
import
logging
from
typing
import
Dict
,
List
,
Tuple
from
collections
import
defaultdict
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
einops
import
torch
import
torch.distributed
from
torch.distributed
import
P2POp
...
...
@@ -22,6 +24,7 @@ from sglang.srt.managers.expert_location import (
ExpertLocationMetadata
,
get_global_expert_location_metadata
,
)
from
sglang.srt.utils
import
get_bool_env_var
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -59,6 +62,8 @@ def _update_expert_weights(
nnodes
:
int
,
rank
:
int
,
):
log_metrics
=
get_bool_env_var
(
"SGLANG_EXPERT_LOCATION_UPDATER_LOG_METRICS"
)
temp_buffers
=
create_temp_buffers
(
next
(
iter
(
routed_experts_weights_of_layer
.
values
()))
)
...
...
@@ -83,6 +88,8 @@ def _update_expert_weights(
num_local_physical_experts
=
num_local_physical_experts
,
num_gpu_per_node
=
num_gpu_per_node
,
rank
=
rank
,
world_size
=
world_size
,
log_metrics
=
log_metrics
,
)
...
...
@@ -98,7 +105,9 @@ def update_expert_weights_single_layer(
num_local_physical_experts
:
int
,
num_gpu_per_node
:
int
,
rank
:
int
,
world_size
:
Optional
[
int
]
=
None
,
debug
:
bool
=
False
,
log_metrics
:
bool
=
False
,
):
assert
all
(
tensor
.
shape
[
0
]
==
num_local_physical_experts
...
...
@@ -130,6 +139,14 @@ def update_expert_weights_single_layer(
_execute_p2p_ops
(
p2p_op_infos
)
_execute_buffer2weight_copies
(
buffer2weight_copy_infos
)
if
log_metrics
:
_log_p2p_op_metrics
(
p2p_op_infos
,
world_size
=
world_size
,
num_gpu_per_node
=
num_gpu_per_node
,
self_node_id
=
self_node_id
,
)
if
debug
:
output_logs
.
append
(
f
"
{
p2p_op_infos
=
}
"
)
output_logs
.
append
(
f
"
{
buffer2weight_copy_infos
=
}
"
)
...
...
@@ -429,3 +446,53 @@ def _deduplicate_ordered(arr: List[int]):
if
len
(
output
)
==
0
or
item
!=
output
[
-
1
]:
output
.
append
(
item
)
return
output
def
_log_p2p_op_metrics
(
p2p_op_infos
:
List
[
Tuple
[
int
,
List
[
P2POp
]]],
num_gpu_per_node
:
int
,
world_size
:
int
,
self_node_id
:
int
,
):
text
=
""
all_ops
=
[
op
for
_
,
ops
in
p2p_op_infos
for
op
in
ops
]
for
direction
,
ops
in
_group_by
(
all_ops
,
_get_direction_from_op
).
items
():
nbytes_of_gpu
=
[
0
]
*
world_size
for
op
in
ops
:
nbytes_of_gpu
[
op
.
peer
]
+=
op
.
tensor
.
nbytes
nbytes_of_gpu
=
torch
.
tensor
(
nbytes_of_gpu
,
dtype
=
torch
.
int64
)
nbytes_of_node
=
einops
.
reduce
(
nbytes_of_gpu
,
"(num_nodes num_gpu_per_node) -> num_nodes"
,
num_gpu_per_node
=
num_gpu_per_node
,
reduction
=
"sum"
,
)
nbytes_curr_node
=
nbytes_of_node
[
self_node_id
]
nbytes_cross_node
=
torch
.
sum
(
nbytes_of_node
)
-
nbytes_curr_node
text
+=
(
f
"
{
direction
}
_nbytes_of_gpu=
{
nbytes_of_gpu
.
tolist
()
}
"
f
"
{
direction
}
_nbytes_of_node=
{
nbytes_of_node
.
tolist
()
}
"
f
"
{
direction
}
_nbytes_curr_node=
{
nbytes_curr_node
.
item
()
}
"
f
"
{
direction
}
_nbytes_cross_node=
{
nbytes_cross_node
.
item
()
}
"
)
logger
.
info
(
f
"[ExpertLocationUpdater]
{
text
}
"
)
def
_get_direction_from_op
(
op
:
P2POp
):
if
op
.
op
==
torch
.
distributed
.
isend
:
return
"isend"
if
op
.
op
==
torch
.
distributed
.
irecv
:
return
"irecv"
raise
NotImplementedError
def
_group_by
(
items
,
keyfunc
):
ans
=
defaultdict
(
list
)
for
item
in
items
:
ans
[
keyfunc
(
item
)].
append
(
item
)
return
dict
(
ans
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment