Commit 2c3b71d6 (Unverified) · authored May 30, 2025 by fzyzcjy · committed by GitHub, May 29, 2025
Improve EPLB logical to physical dispatch map (#6727)
parent 51cdd81f

Showing 1 changed file with 66 additions and 32 deletions:

- python/sglang/srt/managers/expert_location.py (+66 −32)
python/sglang/srt/managers/expert_location.py @ 2c3b71d6

```diff
@@ -13,6 +13,7 @@
 # ==============================================================================
 import json
 import logging
+import random
 from dataclasses import dataclass
 from pathlib import Path
 from typing import List, Optional
@@ -205,10 +206,10 @@ class ExpertLocationMetadata:
             logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
             logical_to_rank_dispatch_physical_map=compute_logical_to_rank_dispatch_physical_map(
                 logical_to_all_physical_map=logical_to_all_physical_map,
-                logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
                 num_gpus=ep_size,
                 num_physical_experts=num_physical_experts,
-                ep_rank=torch.distributed.get_rank(),
+                # TODO improve when we have real EP rank
+                ep_rank=torch.distributed.get_rank() % ep_size,
             ),
         )
```
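Note on the call-site hunk above: the rewritten function derives replica validity from the `-1` padding itself, so `logical_to_all_physical_map_num_valid` is no longer forwarded, and the global rank is now folded into the EP group. A minimal sketch of that fold, assuming a hypothetical 8-rank job with `ep_size=4` (the job size and loop are illustrative, not from the commit):

```python
# Hypothetical illustration of ep_rank = torch.distributed.get_rank() % ep_size.
ep_size = 4
for global_rank in range(8):  # assume an 8-rank job (assumption)
    ep_rank = global_rank % ep_size
    # global ranks 0..7 map to ep_rank 0, 1, 2, 3, 0, 1, 2, 3
```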
```diff
@@ -296,49 +297,82 @@ def _pad_nested_array(arr, pad_value):
     return padded
 
 
-# TODO use more sophisticated approaches
+# TODO optimize performance (rewrite and/or run in separate process with overlap)
 def compute_logical_to_rank_dispatch_physical_map(
     logical_to_all_physical_map: torch.Tensor,
-    logical_to_all_physical_map_num_valid: torch.Tensor,
     num_gpus: int,
     num_physical_experts: int,
     ep_rank: int,
-    base_seed: int = 42,
+    seed: int = 42,
 ):
-    device = logical_to_all_physical_map.device
+    r = random.Random(seed)
 
     num_local_physical_experts = num_physical_experts // num_gpus
     num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape
+    dtype = logical_to_all_physical_map.dtype
 
-    g = torch.Generator(device=device)
-    g.manual_seed(base_seed + ep_rank)
-
-    output_shape = (num_layers, num_logical_experts)
-    chosen_index = (
-        torch.randint(
-            0, 65536, output_shape, dtype=torch.int32, device=device, generator=g
-        )
-        % logical_to_all_physical_map_num_valid
-    )
-    logical_to_rank_dispatch_physical_map = torch.gather(
-        logical_to_all_physical_map, dim=2, index=chosen_index.unsqueeze(-1)
-    ).squeeze(-1)
-    assert logical_to_rank_dispatch_physical_map.shape == output_shape
-
-    for index in range(logical_to_all_physical_map_num_valid.max().item()):
-        partial_logical_to_all_physical_map = logical_to_all_physical_map[:, :, index]
-        is_valid = partial_logical_to_all_physical_map != -1
-        is_same_gpu = (
-            partial_logical_to_all_physical_map // num_local_physical_experts
-        ) == ep_rank
-        logical_to_rank_dispatch_physical_map = torch.where(
-            is_valid & is_same_gpu,
-            partial_logical_to_all_physical_map,
-            logical_to_rank_dispatch_physical_map,
-        )
+    logical_to_rank_dispatch_physical_map = torch.full(
+        size=(num_gpus, num_layers, num_logical_experts),
+        fill_value=-1,
+        dtype=dtype,
+    )
+
+    for layer_id in range(num_layers):
+        for logical_expert_id in range(num_logical_experts):
+            candidate_physical_expert_ids = _logical_to_all_physical_raw(
+                logical_to_all_physical_map, layer_id, logical_expert_id
+            )
+            output_partial = logical_to_rank_dispatch_physical_map[
+                :, layer_id, logical_expert_id
+            ]
+
+            # Prefer a physical replica that lives on this GPU, if any.
+            for gpu_id in range(num_gpus):
+                same_gpu_physical_expert_ids = [
+                    physical_expert_id
+                    for physical_expert_id in candidate_physical_expert_ids
+                    if _compute_gpu_id_of_physical_expert(
+                        physical_expert_id, num_local_physical_experts
+                    )
+                    == gpu_id
+                ]
+                if len(same_gpu_physical_expert_ids) > 0:
+                    output_partial[gpu_id] = same_gpu_physical_expert_ids[0]
+
+            # Fill remaining GPUs by choosing fairly among all replicas.
+            num_remain = torch.sum(output_partial == -1).item()
+            output_partial[output_partial == -1] = torch.tensor(
+                _fair_choices(candidate_physical_expert_ids, k=num_remain, r=r),
+                dtype=dtype,
+            )
 
     assert torch.all(logical_to_rank_dispatch_physical_map != -1)
-    return logical_to_rank_dispatch_physical_map
+
+    device = logical_to_all_physical_map.device
+    return logical_to_rank_dispatch_physical_map[ep_rank, :, :].to(device)
+
+
+def _logical_to_all_physical_raw(
+    logical_to_all_physical_map, layer_id: int, logical_expert_id: int
+) -> List[int]:
+    return [
+        physical_expert_id
+        for physical_expert_id in logical_to_all_physical_map[
+            layer_id, logical_expert_id
+        ].tolist()
+        if physical_expert_id != -1
+    ]
+
+
+def _compute_gpu_id_of_physical_expert(
+    physical_expert_id: int, num_local_physical_experts: int
+) -> int:
+    return physical_expert_id // num_local_physical_experts
+
+
+def _fair_choices(arr: List, k: int, r: random.Random) -> List:
+    quotient, remainder = divmod(k, len(arr))
+    ans = arr * quotient + r.sample(arr, k=remainder)
+    r.shuffle(ans)
+    return ans
 
 
 @dataclass
```
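To see what the rewritten dispatch computation produces end to end, here is a minimal sketch (an illustration, not part of the commit) that calls the function directly on a toy mapping; passing `ep_rank` explicitly means no `torch.distributed` initialization is needed:

```python
import torch

from sglang.srt.managers.expert_location import (
    compute_logical_to_rank_dispatch_physical_map,
)

# 1 layer, 2 logical experts, up to 2 physical replicas each (-1 is padding).
# With num_physical_experts=4 and num_gpus=2, physical experts 0-1 live on
# GPU 0 and experts 2-3 on GPU 1.
logical_to_all_physical_map = torch.tensor(
    [[[0, 2], [3, -1]]]  # logical 0 has replicas {0, 2}; logical 1 only {3}
)

for ep_rank in range(2):
    dispatch = compute_logical_to_rank_dispatch_physical_map(
        logical_to_all_physical_map=logical_to_all_physical_map,
        num_gpus=2,
        num_physical_experts=4,
        ep_rank=ep_rank,
    )
    print(ep_rank, dispatch.tolist())

# Each rank prefers its local replica of logical expert 0:
# rank 0 -> [[0, 3]], rank 1 -> [[2, 3]]
```

Since the call site passes no `seed`, every rank builds the same full `(num_gpus, num_layers, num_logical_experts)` map with the default seed and then slices out its own row, so the ranks stay consistent with each other without extra communication.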
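The `_fair_choices` helper added above is also easy to check in isolation: over `k` draws it picks every candidate either `k // len(arr)` or `k // len(arr) + 1` times. A small sketch (again illustrative, assuming the module's private helper is importable):

```python
import random
from collections import Counter

from sglang.srt.managers.expert_location import _fair_choices

r = random.Random(0)
picks = _fair_choices([10, 11, 12], k=7, r=r)  # 7 draws over 3 candidates
# divmod(7, 3) = (2, 1): each candidate twice, plus one random extra pick.
print(sorted(Counter(picks).values()))  # [2, 2, 3]: as even as possible
```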