sglang / Commits

Commit fc992a09 (unverified)

Support updating expert locations dynamically (#6388)

Authored May 22, 2025 by fzyzcjy; committed by GitHub May 21, 2025. Parent: 121f92c5

Showing 5 changed files with 723 additions and 0 deletions (+723, -0)
- python/sglang/srt/managers/expert_location.py (+21, -0)
- python/sglang/srt/model_executor/expert_location_updater.py (+420, -0)
- python/sglang/srt/model_executor/model_runner.py (+12, -0)
- python/sglang/srt/models/deepseek_v2.py (+15, -0)
- test/srt/test_expert_location_updater.py (+255, -0)
python/sglang/srt/managers/expert_location.py (view file @ fc992a09)
@@ -22,6 +22,7 @@ import torch.distributed

```python
import torch.nn.functional as F

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.managers import deepseek_eplb
from sglang.srt.model_loader import get_model_architecture
from sglang.srt.server_args import ServerArgs
```

@@ -207,6 +208,26 @@ class ExpertLocationMetadata:

```python
            ),
        )

    # -------------------------------- mutation ------------------------------------

    def update(
        self,
        other: "ExpertLocationMetadata",
    ):
        for field in [
            "ep_size",
        ]:
            assert getattr(self, field) == getattr(other, field)

        for field in [
            "physical_to_logical_map",
            "logical_to_all_physical_map",
            "logical_to_all_physical_map_num_valid",
            "logical_to_rank_dispatch_physical_map",
        ]:
            dst = getattr(self, field)
            dst[...] = getattr(other, field)

    # -------------------------------- usage ------------------------------------

    def logical_to_all_physical(
        ...
```
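The in-place assignment `dst[...] = getattr(other, field)` is the key detail: it overwrites the values of the existing mapping tensors rather than rebinding the attributes, so any code that already holds a reference to the old tensors sees the new placement. A minimal sketch of that behaviour (hypothetical tensors, not part of the commit):

```python
import torch

dst = torch.tensor([0, 1, 2, 3])
alias = dst                        # e.g. a reference cached elsewhere in the model
src = torch.tensor([3, 2, 1, 0])

dst[...] = src                     # element-wise in-place copy into the same storage
assert alias.tolist() == [3, 2, 1, 0]
```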
python/sglang/srt/model_executor/expert_location_updater.py (new file, 0 → 100644, view file @ fc992a09)
```python
# Copyright 2023-2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import logging
from typing import Dict, List, Tuple

import torch
import torch.distributed
from torch.distributed import P2POp

from sglang.srt.managers.expert_location import (
    ExpertLocationMetadata,
    get_global_expert_location_metadata,
)

logger = logging.getLogger(__name__)


def update_expert_location(
    routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
    new_expert_location_metadata: ExpertLocationMetadata,
    nnodes: int,
    rank: int,
):
    old_expert_location_metadata = get_global_expert_location_metadata()
    _update_expert_weights(
        routed_experts_weights_of_layer,
        old_expert_location_metadata,
        new_expert_location_metadata,
        nnodes,
        rank,
    )
    old_expert_location_metadata.update(new_expert_location_metadata)


def _update_expert_weights(
    routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
    old_expert_location_metadata: ExpertLocationMetadata,
    new_expert_location_metadata: ExpertLocationMetadata,
    nnodes: int,
    rank: int,
):
    temp_buffers = create_temp_buffers(
        next(iter(routed_experts_weights_of_layer.values()))
    )

    world_size = torch.distributed.get_world_size()
    num_local_physical_experts = old_expert_location_metadata.num_local_physical_experts
    num_gpu_per_node = world_size // nnodes

    old_physical_to_logical_map = (
        old_expert_location_metadata.physical_to_logical_map.tolist()
    )
    new_physical_to_logical_map = (
        new_expert_location_metadata.physical_to_logical_map.tolist()
    )

    for layer_id in sorted(routed_experts_weights_of_layer.keys()):
        update_expert_weights_single_layer(
            routed_experts_weights=routed_experts_weights_of_layer[layer_id],
            temp_buffers=temp_buffers,
            old_physical_to_logical_map=old_physical_to_logical_map[layer_id],
            new_physical_to_logical_map=new_physical_to_logical_map[layer_id],
            num_local_physical_experts=num_local_physical_experts,
            num_gpu_per_node=num_gpu_per_node,
            rank=rank,
        )


def create_temp_buffers(sample_tensors):
    return [torch.empty_like(tensor) for tensor in sample_tensors]
```
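For orientation, a minimal sketch of the input layout these drivers expect, with hypothetical shapes (the real dict is built in deepseek_v2.py later in this commit); the local-expert dimension comes first, matching the assertion in `update_expert_weights_single_layer` below:

```python
import torch

from sglang.srt.model_executor.expert_location_updater import create_temp_buffers

# Hypothetical layer dict: layer_id -> list of expert weight tensors,
# each with the local physical experts along dim 0.
num_local_physical_experts = 2
w_up = torch.zeros(num_local_physical_experts, 8, 4)    # made-up shapes
w_down = torch.zeros(num_local_physical_experts, 4, 8)
routed_experts_weights_of_layer = {0: [w_up, w_down], 1: [w_up.clone(), w_down.clone()]}

# One set of scratch buffers is built from a sample layer and reused for every
# layer, since all layers share the same per-tensor shapes.
temp_buffers = create_temp_buffers(next(iter(routed_experts_weights_of_layer.values())))
assert [t.shape for t in temp_buffers] == [w_up.shape, w_down.shape]
```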
```python
def update_expert_weights_single_layer(
    routed_experts_weights: List[torch.Tensor],
    temp_buffers: List[torch.Tensor],
    old_physical_to_logical_map: List[int],  # (num_physical_Experts,)
    new_physical_to_logical_map: List[int],  # (num_physical_Experts,)
    num_local_physical_experts: int,
    num_gpu_per_node: int,
    rank: int,
    debug: bool = False,
):
    assert all(
        tensor.shape[0] == num_local_physical_experts
        for tensor in routed_experts_weights
    ), f"{num_local_physical_experts=} {[x.shape for x in routed_experts_weights]=}"

    output_logs = [] if debug else None

    num_physical_experts = len(old_physical_to_logical_map)
    num_tensors = len(routed_experts_weights)

    self_node_id = rank // num_gpu_per_node

    local_expert_location_range = (
        rank * num_local_physical_experts,
        (rank + 1) * num_local_physical_experts,
    )

    def _entrypoint():
        # List[Tuple[logical_expert_id, List[P2POp]]]
        p2p_op_infos: List[Tuple[int, List[P2POp]]] = []
        # List[Tuple[temp_buffers_expert_location, routed_experts_weights_expert_location]]
        buffer2weight_copy_infos: List[Tuple[int, int]] = []

        _handle_recv(buffer2weight_copy_infos, p2p_op_infos)
        _create_isend_ops(p2p_op_infos)
        _execute_p2p_ops(p2p_op_infos)
        _execute_buffer2weight_copies(buffer2weight_copy_infos)

        if debug:
            output_logs.append(f"{p2p_op_infos=}")
            output_logs.append(f"{buffer2weight_copy_infos=}")

    def _handle_recv(buffer2weight_copy_infos, p2p_op_infos):
        for dst_expert_location in range(*local_expert_location_range):
            _handle_recv_of_dst_expert_location(
                dst_expert_location, buffer2weight_copy_infos, p2p_op_infos
            )

    def _handle_recv_of_dst_expert_location(
        dst_expert_location: int, buffer2weight_copy_infos, p2p_op_infos
    ):
        logical_expert_id = new_physical_to_logical_map[dst_expert_location]

        # case 1: unchanged
        if old_physical_to_logical_map[dst_expert_location] == logical_expert_id:
            if debug:
                output_logs.append(
                    f"handle_recv_of_dst_expert_location {dst_expert_location=} case=unchanged"
                )
            return

        # case 2: same-gpu
        for src_expert_location in range(*local_expert_location_range):
            if old_physical_to_logical_map[src_expert_location] == logical_expert_id:
                for i in range(num_tensors):
                    _get_tensor(temp_buffers, i, dst_expert_location).copy_(
                        _get_tensor(routed_experts_weights, i, src_expert_location)
                    )
                buffer2weight_copy_infos.append(
                    (dst_expert_location, dst_expert_location)
                )
                if debug:
                    output_logs.append(
                        f"handle_recv_of_dst_expert_location {dst_expert_location=} case=same-gpu {src_expert_location=}"
                    )
                return

        # case 3: free-rider
        for src_expert_location in range(
            rank * num_local_physical_experts, dst_expert_location
        ):
            if new_physical_to_logical_map[src_expert_location] == logical_expert_id:
                buffer2weight_copy_infos.append(
                    (src_expert_location, dst_expert_location)
                )
                if debug:
                    output_logs.append(
                        f"handle_recv_of_dst_expert_location {dst_expert_location=} case=free-rider {src_expert_location=}"
                    )
                return

        same_node_mapping, cross_node_mapping, need_comm_self_node_dst_ranks = (
            _compute_comm_info(logical_expert_id=logical_expert_id)
        )

        # case 4: same-node
        if rank in need_comm_self_node_dst_ranks:
            chosen_src_rank = same_node_mapping.chunk_value_from_element_value(
                element_value=rank
            )
            _create_p2p_recv_and_buffer2weight_copy(
                buffer2weight_copy_infos,
                p2p_op_infos,
                src_rank=chosen_src_rank,
                logical_expert_id=logical_expert_id,
                dst_expert_location=dst_expert_location,
            )
            if debug:
                output_logs.append(
                    f"handle_recv_of_dst_expert_location {dst_expert_location=} case=same-node {chosen_src_rank=}"
                )
            return

        # case 5: cross-node
        # Future work: can optimize when there are multiple ranks in the same dst node that uses the same logical expert
        chosen_src_rank = cross_node_mapping.chunk_value_from_element_value(
            element_value=rank
        )
        _create_p2p_recv_and_buffer2weight_copy(
            buffer2weight_copy_infos,
            p2p_op_infos,
            src_rank=chosen_src_rank,
            logical_expert_id=logical_expert_id,
            dst_expert_location=dst_expert_location,
        )
        if debug:
            output_logs.append(
                f"handle_recv_of_dst_expert_location {dst_expert_location=} case=cross-node {chosen_src_rank=}"
            )
        return

    def _create_p2p_recv_and_buffer2weight_copy(
        buffer2weight_copy_infos,
        p2p_op_infos,
        *,
        logical_expert_id: int,
        src_rank: int,
        dst_expert_location: int,
    ):
        p2p_op_infos.append(
            (
                logical_expert_id,
                [
                    P2POp(
                        op=torch.distributed.irecv,
                        tensor=_get_tensor(temp_buffers, i, dst_expert_location),
                        peer=src_rank,
                    )
                    for i in range(num_tensors)
                ],
            )
        )
        buffer2weight_copy_infos.append((dst_expert_location, dst_expert_location))

    def _create_isend_ops(p2p_op_infos):
        handled_logical_expert_ids = set()
        for src_expert_location in range(*local_expert_location_range):
            logical_expert_id = old_physical_to_logical_map[src_expert_location]

            if logical_expert_id in handled_logical_expert_ids:
                continue
            handled_logical_expert_ids.add(logical_expert_id)

            _create_isend_ops_of_logical_expert_id(
                logical_expert_id, src_expert_location, p2p_op_infos
            )

    def _create_isend_ops_of_logical_expert_id(
        logical_expert_id, src_expert_location, p2p_op_infos
    ):
        same_node_mapping, cross_node_mapping, need_comm_self_node_dst_ranks = (
            _compute_comm_info(logical_expert_id=logical_expert_id)
        )

        same_node_dst_ranks = same_node_mapping.element_values_from_chunk_value(
            chunk_value=rank
        )
        cross_node_dst_ranks = cross_node_mapping.element_values_from_chunk_value(
            chunk_value=rank
        )
        all_dst_ranks = same_node_dst_ranks + cross_node_dst_ranks

        if debug:
            output_logs.append(
                f"create_isend_ops_of_logical_expert_id {logical_expert_id=} {src_expert_location=} {same_node_dst_ranks=} {cross_node_dst_ranks=}"
            )

        p2p_op_infos.append(
            (
                logical_expert_id,
                [
                    P2POp(
                        op=torch.distributed.isend,
                        tensor=_get_tensor(
                            routed_experts_weights, i, src_expert_location
                        ),
                        peer=dst_rank,
                    )
                    for dst_rank in all_dst_ranks
                    for i in range(num_tensors)
                ],
            )
        )

    def _compute_comm_info(logical_expert_id: int):
        all_src_ranks = _deduplicate_ordered(
            [
                x // num_local_physical_experts
                for x in range(num_physical_experts)
                if old_physical_to_logical_map[x] == logical_expert_id
            ]
        )
        all_src_nodes = [x // num_gpu_per_node for x in all_src_ranks]
        self_node_src_ranks = [
            x for x in all_src_ranks if x // num_gpu_per_node == self_node_id
        ]

        need_comm_dst_ranks = _deduplicate_ordered(
            [
                x // num_local_physical_experts
                for x in range(num_physical_experts)
                if new_physical_to_logical_map[x] == logical_expert_id
                and x // num_local_physical_experts not in all_src_ranks
            ]
        )
        need_comm_self_node_dst_ranks = (
            [x for x in need_comm_dst_ranks if x // num_gpu_per_node == self_node_id]
            if len(self_node_src_ranks) > 0
            else []
        )
        need_comm_cross_node_dst_ranks = [
            x
            for x in need_comm_dst_ranks
            if (x // num_gpu_per_node) not in all_src_nodes
        ]

        same_node_mapping = _ChunkUtils(
            chunk_values=self_node_src_ranks,
            element_values=need_comm_self_node_dst_ranks,
        )

        cross_node_mapping = _ChunkUtils(
            chunk_values=all_src_ranks,
            element_values=need_comm_cross_node_dst_ranks,
        )

        return same_node_mapping, cross_node_mapping, need_comm_self_node_dst_ranks

    def _execute_p2p_ops(p2p_op_infos):
        sorted_infos = sorted(p2p_op_infos, key=lambda info: info[0])
        p2p_ops = [op for _, ops in sorted_infos for op in ops]
        if len(p2p_ops) == 0:
            return

        reqs = torch.distributed.batch_isend_irecv(p2p_ops)
        for req in reqs:
            req.wait()

    def _execute_buffer2weight_copies(buffer2weight_copy_infos):
        for (
            temp_buffers_expert_location,
            routed_experts_weights_expert_location,
        ) in buffer2weight_copy_infos:
            for i in range(num_tensors):
                _get_tensor(
                    routed_experts_weights, i, routed_experts_weights_expert_location
                ).copy_(_get_tensor(temp_buffers, i, temp_buffers_expert_location))

    def _get_tensor(tensors, tensor_index: int, expert_location: int) -> torch.Tensor:
        return tensors[tensor_index][_get_local_expert_location(expert_location)]

    def _get_local_expert_location(expert_location: int) -> int:
        assert (
            local_expert_location_range[0]
            <= expert_location
            < local_expert_location_range[1]
        )
        return expert_location % num_local_physical_experts

    _entrypoint()

    return output_logs
```
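As a sanity check on the case logic above, here is a minimal single-process walk-through (a hypothetical four-expert layout, not part of the commit): with one rank owning every physical expert, all moves resolve through the unchanged/same-gpu cases, no P2P ops are issued, and no process group is needed.

```python
import torch

from sglang.srt.model_executor import expert_location_updater

old_map = [0, 1, 2, 3]   # physical slot -> logical expert, before rebalancing
new_map = [3, 1, 0, 2]   # after rebalancing
# One weight tensor, expert dimension first; logical expert e carries the value 10 * e.
weights = [torch.arange(4.0).reshape(4, 1) * 10]

expert_location_updater.update_expert_weights_single_layer(
    routed_experts_weights=weights,
    temp_buffers=expert_location_updater.create_temp_buffers(weights),
    old_physical_to_logical_map=old_map,
    new_physical_to_logical_map=new_map,
    num_local_physical_experts=4,   # a single rank holds all 4 physical experts
    num_gpu_per_node=1,
    rank=0,
)

# Slot i now holds the weights of logical expert new_map[i].
assert weights[0].squeeze().tolist() == [30.0, 10.0, 0.0, 20.0]
```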
```python
class _ChunkUtils:
    def __init__(self, *, chunk_values: List, element_values: List):
        self.chunk_values = chunk_values
        self.element_values = element_values

    def chunk_value_from_element_value(self, element_value):
        chunk_index = self._chunk_index_from_element_index(
            num_elements=len(self.element_values),
            num_chunks=len(self.chunk_values),
            element_index=self.element_values.index(element_value),
        )
        return self.chunk_values[chunk_index]

    def element_values_from_chunk_value(self, chunk_value) -> List:
        if len(self.element_values) == 0:
            return []
        element_slice = self._element_slice_from_chunk_index(
            num_elements=len(self.element_values),
            num_chunks=len(self.chunk_values),
            chunk_index=self.chunk_values.index(chunk_value),
        )
        return self.element_values[element_slice]

    @staticmethod
    def _chunk_index_from_element_index(
        num_elements: int, num_chunks: int, element_index: int
    ) -> int:
        short_chunk_size, num_long_chunks = divmod(num_elements, num_chunks)
        num_elements_for_long_chunks = num_long_chunks * (short_chunk_size + 1)
        if element_index < num_elements_for_long_chunks:
            return element_index // (short_chunk_size + 1)
        else:
            return (
                num_long_chunks
                + (element_index - num_elements_for_long_chunks) // short_chunk_size
            )

    @staticmethod
    def _element_slice_from_chunk_index(
        num_elements: int, num_chunks: int, chunk_index: int
    ) -> slice:
        short_chunk_size, num_long_chunks = divmod(num_elements, num_chunks)
        start = chunk_index * short_chunk_size + min(chunk_index, num_long_chunks)
        end = start + short_chunk_size + int(chunk_index < num_long_chunks)
        return slice(start, end)


def _deduplicate_ordered(arr: List[int]):
    output = []
    for item in arr:
        if len(output) == 0 or item != output[-1]:
            output.append(item)
    return output
```
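The `_ChunkUtils` helper spreads destination ranks (elements) as evenly as possible across source ranks (chunks), with the earlier chunks taking the longer slices. A small sketch of its behaviour, using hypothetical rank values and assuming the class above is importable from this module:

```python
from sglang.srt.model_executor.expert_location_updater import _ChunkUtils

# Two source ranks (chunks) serve five destination ranks (elements).
mapping = _ChunkUtils(chunk_values=[0, 4], element_values=[1, 2, 3, 5, 6])

assert mapping.element_values_from_chunk_value(chunk_value=0) == [1, 2, 3]
assert mapping.element_values_from_chunk_value(chunk_value=4) == [5, 6]
assert mapping.chunk_value_from_element_value(element_value=5) == 4
```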
python/sglang/srt/model_executor/model_runner.py (view file @ fc992a09)
@@ -57,6 +57,7 @@ from sglang.srt.managers.expert_distribution import (

```python
    set_global_expert_distribution_recorder,
)
from sglang.srt.managers.expert_location import (
    ExpertLocationMetadata,
    compute_initial_expert_location_metadata,
    get_global_expert_location_metadata,
    set_global_expert_location_metadata,
    ...
```

@@ -70,6 +71,7 @@ from sglang.srt.mem_cache.memory_pool import (

```python
    TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
from sglang.srt.model_executor import expert_location_updater
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader import get_model
```

@@ -575,6 +577,16 @@ class ModelRunner:

```python
                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
            ) from None

    def update_expert_location(
        self, new_expert_location_metadata: ExpertLocationMetadata
    ):
        expert_location_updater.update_expert_location(
            self.model.routed_experts_weights_of_layer,
            new_expert_location_metadata,
            nnodes=self.server_args.nnodes,
            rank=self.tp_rank,
        )

    def update_weights_from_disk(
        self, model_path: str, load_format: str
    ) -> tuple[bool, str]:
        ...
```
python/sglang/srt/models/deepseek_v2.py (view file @ fc992a09)
@@ -317,6 +317,13 @@ class DeepseekV2MoE(nn.Module):

```python
    def _enable_deepep_moe(self):
        return global_server_args_dict["enable_deepep_moe"]

    def get_moe_weights(self):
        return [
            x.data
            for name, x in self.experts.named_parameters()
            if name not in ["correction_bias"]
        ]

    def op_gate(self, state):
        if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
            state.forward_batch.forward_mode, state.hidden_states_mlp_input
        ...
```

@@ -1599,6 +1606,14 @@ class DeepseekV2ForCausalLM(nn.Module):

```python
                self_attn.w_vc = w_vc.contiguous()
                self_attn.use_deep_gemm_bmm = True

        # TODO support nextn later
        if not is_nextn:
            self.routed_experts_weights_of_layer = {
                layer_id: layer.mlp.get_moe_weights()
                for layer_id, layer in enumerate(self.model.layers)
                if isinstance(layer.mlp, DeepseekV2MoE)
            }

    def load_weights(
        self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False
    ):
        if is_nextn:
            if hasattr(self.config, "num_nextn_predict_layers"):
                ...
```
test/srt/test_expert_location_updater.py (new file, 0 → 100644, view file @ fc992a09)
```python
import os
import traceback
import unittest
from dataclasses import dataclass
from typing import List

import torch
import torch.distributed
import torch.multiprocessing as mp
from torch.multiprocessing import Process

from sglang.srt.model_executor import expert_location_updater
from sglang.test.test_utils import CustomTestCase, find_available_port
from sglang.utils import is_in_ci


@dataclass
class _TestInfo:
    nnodes: int
    num_logical_experts: int
    num_physical_experts: int
    num_repeat: int = 5000


class TestExpertLocationUpdater(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        mp.set_start_method("spawn", force=True)

    def test_cpu(self):
        self._test_common(device="cpu")
        self._test_core(
            num_gpus=32,
            device="cpu",
            infos=[
                _TestInfo(
                    nnodes=4,
                    num_logical_experts=256,
                    num_physical_experts=288,
                    num_repeat=10000,
                )
            ],
        )

    def test_cpu_slow(self):
        if is_in_ci():
            return
        self._test_core(
            num_gpus=144,
            device="cpu",
            infos=[
                _TestInfo(
                    nnodes=18,
                    num_logical_experts=256,
                    num_physical_experts=288,
                    num_repeat=10000,
                )
            ],
        )

    def test_gpu(self):
        if is_in_ci():
            return
        self._test_common(device="cuda")

    def _test_common(self, device):
        infos = []
        for nnodes in [1, 2, 4]:
            for num_logical_experts in [2, 5, 20, 256]:
                for num_physical_experts in [8, 16, 256, 288]:
                    if num_logical_experts > num_physical_experts:
                        continue
                    infos.append(
                        _TestInfo(
                            nnodes=nnodes,
                            num_logical_experts=num_logical_experts,
                            num_physical_experts=num_physical_experts,
                        )
                    )
        self._test_core(num_gpus=8, device=device, infos=infos)

    def _test_core(
        self,
        num_gpus: int,
        device: str,
        infos: List[_TestInfo],
    ):
        master_port = find_available_port(23456)

        processes = []
        output_reader, output_writer = mp.Pipe(duplex=False)
        for rank in range(num_gpus):
            p = Process(
                target=_run_subprocess,
                kwargs=dict(
                    rank=rank,
                    num_gpus=num_gpus,
                    output_writer=output_writer,
                    master_port=master_port,
                    device=device,
                    infos=infos,
                ),
            )
            p.start()
            processes.append(p)

        for _ in range(num_gpus):
            self.assertTrue(
                output_reader.recv(), f"Subprocess has error, please see logs above."
            )

        for p in processes:
            p.join()


def _run_subprocess(
    rank: int,
    num_gpus: int,
    master_port: int,
    device: str,
    infos: List[_TestInfo],
    output_writer,
):
    try:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(master_port)

        torch.random.manual_seed(42)
        torch.distributed.init_process_group(
            rank=rank,
            world_size=num_gpus,
            backend={"cpu": "gloo", "cuda": None}[device],
        )
        if device == "cuda":
            torch.cuda.set_device(f"cuda:{rank}")

        for info in infos:
            _execute_test(info, rank=rank, num_gpus=num_gpus, device=device)

        execution_ok = True
    except Exception as e:
        print(f"subprocess[{rank=}] has error: {e}", flush=True)
        traceback.print_exc()
        execution_ok = False

    output_writer.send(execution_ok)
    output_writer.close()


def _execute_test(info: _TestInfo, rank: int, num_gpus: int, device: str):
    if rank == 0:
        print(f"Test: {num_gpus=} {info=}", flush=True)

    assert info.num_physical_experts % num_gpus == 0
    num_local_physical_experts = info.num_physical_experts // num_gpus
    assert num_gpus % info.nnodes == 0
    num_gpu_per_node = num_gpus // info.nnodes

    def _create_routed_experts_weights(physical_to_logical_map):
        local_logical_expert_ids = physical_to_logical_map[
            rank * num_local_physical_experts : (rank + 1) * num_local_physical_experts
        ].cpu()
        return [
            local_logical_expert_ids.to(device).clone(),
            torch.tensor(
                [
                    [local_logical_expert_id * 10, local_logical_expert_id * 100]
                    for local_logical_expert_id in local_logical_expert_ids.tolist()
                ],
                device=device,
            ),
        ]

    def _create_physical_to_logical_map():
        if rank == 0:
            ans = torch.concat(
                [
                    torch.arange(0, info.num_logical_experts),
                    torch.randint(
                        0,
                        info.num_logical_experts,
                        (info.num_physical_experts - info.num_logical_experts,),
                    ),
                ]
            )
            ans = ans[torch.randperm(ans.shape[0])]
        else:
            ans = torch.empty((info.num_physical_experts,), dtype=torch.int64)
        assert ans.dtype == torch.int64 and ans.shape == (info.num_physical_experts,)
        ans = ans.to(device)
        torch.distributed.broadcast(ans, src=0)
        return ans.cpu()

    physical_to_logical_map = _create_physical_to_logical_map()
    routed_experts_weights = _create_routed_experts_weights(physical_to_logical_map)

    for i in range(info.num_repeat):
        if rank == 0 and ((i % 500 == 0) or (i == info.num_repeat - 1)):
            print(f"Step {i}/{info.num_repeat}", flush=True)

        new_physical_to_logical_map = _create_physical_to_logical_map()
        expect_new_weights = _create_routed_experts_weights(new_physical_to_logical_map)

        output_logs = expert_location_updater.update_expert_weights_single_layer(
            routed_experts_weights=routed_experts_weights,
            temp_buffers=expert_location_updater.create_temp_buffers(
                routed_experts_weights
            ),
            old_physical_to_logical_map=physical_to_logical_map,
            new_physical_to_logical_map=new_physical_to_logical_map,
            num_local_physical_experts=num_local_physical_experts,
            num_gpu_per_node=num_gpu_per_node,
            rank=rank,
            debug=True,
        )

        local_has_error = not all(
            torch.all(x == y)
            for x, y in zip(routed_experts_weights, expect_new_weights, strict=True)
        )
        global_has_error = torch.tensor(local_has_error, device=device)
        torch.distributed.all_reduce(
            global_has_error, op=torch.distributed.ReduceOp.MAX
        )
        if global_has_error.cpu().item():
            output_logs_str = "\n".join(output_logs)
            local_message = (
                f"===================== rank {rank} ============================\n"
                f"{num_gpus=} {info=}\n"
                f"{routed_experts_weights[0].tolist()=}\n"
                f"{expect_new_weights[0].tolist()=}\n"
                f"{physical_to_logical_map.tolist()=}\n"
                f"{new_physical_to_logical_map.tolist()=}\n"
                f"===logs===\n"
                f"{output_logs_str}\n"
                f"==============================================================\n"
            )
            global_messages = ([None] * num_gpus) if rank == 0 else None
            torch.distributed.gather_object(local_message, global_messages, dst=0)
            if rank == 0:
                print("\n\n".join(global_messages), flush=True)
            raise AssertionError(f"Error happens, see logs above")

        physical_to_logical_map = new_physical_to_logical_map


if __name__ == "__main__":
    unittest.main()
```
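The test can be launched directly, e.g. `python test/srt/test_expert_location_updater.py`: the CPU variants spawn gloo-backed subprocesses, so the randomized stress test runs without any GPU, while `test_cpu_slow` and `test_gpu` return early when running in CI.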