Commit 2c3b71d6 (unverified)
Authored May 30, 2025 by fzyzcjy; committed by GitHub May 29, 2025
Parent: 51cdd81f

Improve EPLB logical to physical dispatch map (#6727)
Showing 1 changed file with 66 additions and 32 deletions:

python/sglang/srt/managers/expert_location.py (+66, -32)
--- a/python/sglang/srt/managers/expert_location.py
+++ b/python/sglang/srt/managers/expert_location.py
@@ -13,6 +13,7 @@
 # ==============================================================================
 import json
 import logging
+import random
 from dataclasses import dataclass
 from pathlib import Path
 from typing import List, Optional
@@ -205,10 +206,10 @@ class ExpertLocationMetadata:
             logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
             logical_to_rank_dispatch_physical_map=compute_logical_to_rank_dispatch_physical_map(
                 logical_to_all_physical_map=logical_to_all_physical_map,
-                logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
                 num_gpus=ep_size,
                 num_physical_experts=num_physical_experts,
-                ep_rank=torch.distributed.get_rank(),
+                # TODO improve when we have real EP rank
+                ep_rank=torch.distributed.get_rank() % ep_size,
             ),
         )
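A side note on the ep_rank argument: the new call site folds the global torch.distributed rank into the expert-parallel group with a modulo, as flagged by the TODO above. A minimal sketch of that folding (the 8-rank, two-group layout is a hypothetical illustration, not from the commit):

# Hypothetical layout: 8 global ranks with ep_size = 4, i.e. two EP groups.
ep_size = 4
for global_rank in range(8):
    ep_rank = global_rank % ep_size  # same folding as ep_rank=... % ep_size above
    print(f"global rank {global_rank} -> EP rank {ep_rank}")

Ranks 0-3 and 4-7 each map onto EP ranks 0-3, which is why the diff marks this as a stopgap until a real EP rank is plumbed through.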
@@ -296,49 +297,82 @@ def _pad_nested_array(arr, pad_value):
     return padded
 
 
-# TODO use more sophisticated approaches
+# TODO optimize performance (rewrite and/or run in separate process with overlap)
 def compute_logical_to_rank_dispatch_physical_map(
     logical_to_all_physical_map: torch.Tensor,
-    logical_to_all_physical_map_num_valid: torch.Tensor,
     num_gpus: int,
     num_physical_experts: int,
     ep_rank: int,
-    base_seed: int = 42,
+    seed: int = 42,
 ):
-    device = logical_to_all_physical_map.device
+    r = random.Random(seed)
 
     num_local_physical_experts = num_physical_experts // num_gpus
     num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape
+    dtype = logical_to_all_physical_map.dtype
 
-    g = torch.Generator(device=device)
-    g.manual_seed(base_seed + ep_rank)
-
-    output_shape = (num_layers, num_logical_experts)
-    chosen_index = (
-        torch.randint(
-            0, 65536, output_shape, dtype=torch.int32, device=device, generator=g
-        )
-        % logical_to_all_physical_map_num_valid
-    )
-    logical_to_rank_dispatch_physical_map = torch.gather(
-        logical_to_all_physical_map, dim=2, index=chosen_index.unsqueeze(-1)
-    ).squeeze(-1)
-    assert logical_to_rank_dispatch_physical_map.shape == output_shape
-
-    for index in range(logical_to_all_physical_map_num_valid.max().item()):
-        partial_logical_to_all_physical_map = logical_to_all_physical_map[:, :, index]
-        is_valid = partial_logical_to_all_physical_map != -1
-        is_same_gpu = (
-            partial_logical_to_all_physical_map // num_local_physical_experts
-        ) == ep_rank
-
-        logical_to_rank_dispatch_physical_map = torch.where(
-            is_valid & is_same_gpu,
-            partial_logical_to_all_physical_map,
-            logical_to_rank_dispatch_physical_map,
-        )
+    logical_to_rank_dispatch_physical_map = torch.full(
+        size=(num_gpus, num_layers, num_logical_experts),
+        fill_value=-1,
+        dtype=dtype,
+    )
+
+    for layer_id in range(num_layers):
+        for logical_expert_id in range(num_logical_experts):
+            candidate_physical_expert_ids = _logical_to_all_physical_raw(
+                logical_to_all_physical_map, layer_id, logical_expert_id
+            )
+            output_partial = logical_to_rank_dispatch_physical_map[
+                :, layer_id, logical_expert_id
+            ]
+
+            for gpu_id in range(num_gpus):
+                same_gpu_physical_expert_ids = [
+                    physical_expert_id
+                    for physical_expert_id in candidate_physical_expert_ids
+                    if _compute_gpu_id_of_physical_expert(
+                        physical_expert_id, num_local_physical_experts
+                    )
+                    == gpu_id
+                ]
+                if len(same_gpu_physical_expert_ids) > 0:
+                    output_partial[gpu_id] = same_gpu_physical_expert_ids[0]
+
+            num_remain = torch.sum(output_partial == -1).item()
+            output_partial[output_partial == -1] = torch.tensor(
+                _fair_choices(candidate_physical_expert_ids, k=num_remain, r=r),
+                dtype=dtype,
+            )
 
     assert torch.all(logical_to_rank_dispatch_physical_map != -1)
-    return logical_to_rank_dispatch_physical_map
+
+    device = logical_to_all_physical_map.device
+    return logical_to_rank_dispatch_physical_map[ep_rank, :, :].to(device)
+
+
+def _logical_to_all_physical_raw(
+    logical_to_all_physical_map, layer_id: int, logical_expert_id: int
+) -> List[int]:
+    return [
+        physical_expert_id
+        for physical_expert_id in logical_to_all_physical_map[
+            layer_id, logical_expert_id
+        ].tolist()
+        if physical_expert_id != -1
+    ]
+
+
+def _compute_gpu_id_of_physical_expert(
+    physical_expert_id: int, num_local_physical_experts: int
+) -> int:
+    return physical_expert_id // num_local_physical_experts
+
+
+def _fair_choices(arr: List, k: int, r: random.Random) -> List:
+    quotient, remainder = divmod(k, len(arr))
+    ans = arr * quotient + r.sample(arr, k=remainder)
+    r.shuffle(ans)
+    return ans
 
 
 @dataclass
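The _fair_choices helper added at the bottom of the hunk spreads k picks as evenly as possible over the candidate list: every candidate appears k // len(arr) times, the remainder is sampled without replacement, and the result is shuffled. A quick standalone check, copying the helper verbatim from the diff so it runs without sglang installed:

import random
from collections import Counter
from typing import List


def _fair_choices(arr: List, k: int, r: random.Random) -> List:
    quotient, remainder = divmod(k, len(arr))
    ans = arr * quotient + r.sample(arr, k=remainder)
    r.shuffle(ans)
    return ans


r = random.Random(42)
print(Counter(_fair_choices([10, 20, 30], k=7, r=r)))
# 7 = 2 * 3 + 1: every candidate appears twice, and one randomly chosen
# candidate appears a third time, e.g. Counter({20: 3, 10: 2, 30: 2})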
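To see the new dispatch policy end to end, here is a minimal sketch; it assumes an sglang checkout containing this commit is importable, and the toy mapping tensor is invented for illustration. Logical expert 0 has one physical replica per GPU, while logical expert 1 exists only on GPU 1:

import torch

from sglang.srt.managers.expert_location import (
    compute_logical_to_rank_dispatch_physical_map,
)

# Toy map: 1 layer, 2 logical experts, up to 2 replicas each; 4 physical
# experts over 2 GPUs (0-1 on GPU 0, 2-3 on GPU 1); -1 = unused replica slot.
logical_to_all_physical_map = torch.tensor([[[0, 2], [3, -1]]])

for ep_rank in range(2):
    dispatch = compute_logical_to_rank_dispatch_physical_map(
        logical_to_all_physical_map,
        num_gpus=2,
        num_physical_experts=4,
        ep_rank=ep_rank,
    )
    print(ep_rank, dispatch.tolist())
# Expected: rank 0 -> [[0, 3]] (expert 0 served by its local replica,
#           expert 1 routed to GPU 1); rank 1 -> [[2, 3]] (both local).

Because every rank builds the same full (num_gpus, num_layers, num_logical_experts) map from the shared seed and only then slices out its own row, the per-rank dispatch tables stay mutually consistent without any communication.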