Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
6df81e8a
"tests/vscode:/vscode.git/clone" did not exist on "805aa93789fe9c95dd8d5a3ceac100d33f584ec7"
Unverified
Commit
6df81e8a
authored
May 29, 2025
by
fzyzcjy
Committed by
GitHub
May 29, 2025
Browse files
Support tuning DeepEP configs (#6742)
parent
3ab7d9b5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
694 additions
and
0 deletions
+694
-0
benchmark/kernels/deepep/deepep_utils.py
benchmark/kernels/deepep/deepep_utils.py
+218
-0
benchmark/kernels/deepep/tuning_deepep.py
benchmark/kernels/deepep/tuning_deepep.py
+476
-0
No files found.
benchmark/kernels/deepep/deepep_utils.py
0 → 100644
View file @
6df81e8a
# ADAPTED FROM https://github.com/deepseek-ai/DeepEP/blob/main/tests/utils.py
import
os
import
sys
from
typing
import
Optional
import
numpy
as
np
import
torch
import
torch.distributed
as
dist
def init_dist(local_rank: int, num_local_ranks: int, args):
    """Join the NCCL process group and set CUDA/bfloat16 defaults for this process.

    Args:
        local_rank: Index of this process on its node; also used as the CUDA
            device index.
        num_local_ranks: Number of processes per node.
        args: Object exposing ``master_addr``, ``master_port``, ``nnodes`` and
            ``node_rank``.

    Returns:
        A ``(rank, world_size, group)`` tuple, where ``group`` is a new process
        group spanning ranks ``0 .. num_local_ranks * nnodes - 1``.

    Raises:
        AssertionError: If the layout is neither a single node with fewer than
            8 local ranks nor exactly 8 local ranks per node.
    """
    master_addr, master_port = args.master_addr, args.master_port
    node_count, this_node = args.nnodes, args.node_rank

    # Only two layouts are accepted: one partially filled node, or full
    # 8-rank nodes.
    single_partial_node = num_local_ranks < 8 and node_count == 1
    assert single_partial_node or num_local_ranks == 8

    total_ranks = node_count * num_local_ranks
    global_rank = this_node * num_local_ranks + local_rank
    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://{master_addr}:{master_port}",
        world_size=total_ranks,
        rank=global_rank,
    )

    # Every tensor created afterwards defaults to bfloat16 on the GPU.
    torch.set_default_dtype(torch.bfloat16)
    torch.set_default_device("cuda")
    torch.cuda.set_device(local_rank)

    rank = dist.get_rank()
    world_size = dist.get_world_size()
    group = dist.new_group(list(range(num_local_ranks * node_count)))
    return (rank, world_size, group)
def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Return ``1 - 2*sum(a*b) / sum(a*a + b*b)`` for ``a = x + 1``, ``b = y + 1``.

    Both inputs are promoted to float64 and shifted by one before the ratio is
    formed. The result is a Python float that is 0 when the tensors are equal.
    """
    lhs = x.double() + 1
    rhs = y.double() + 1
    energy = (lhs * lhs + rhs * rhs).sum()
    similarity = 2 * (lhs * rhs).sum() / energy
    return (1 - similarity).item()
def
per_token_cast_to_fp8
(
x
:
torch
.
Tensor
):
assert
x
.
dim
()
==
2
and
x
.
size
(
1
)
%
128
==
0
m
,
n
=
x
.
shape
x_view
=
x
.
view
(
m
,
-
1
,
128
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
2
).
view
(
m
,
-
1
).
clamp
(
1e-4
)
return
(
x_view
*
(
448.0
/
x_amax
.
unsqueeze
(
2
))).
to
(
torch
.
float8_e4m3fn
).
view
(
m
,
n
),
(
x_amax
/
448.0
).
view
(
m
,
-
1
)
def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):
    """Dequantize block-scaled values back to bfloat16.

    Args:
        x_fp8: Tensor whose rows are split into 128-wide blocks.
        x_scales: One scale per (row, block); reshaped to ``(rows, -1, 1)``.

    Returns:
        A bfloat16 tensor with the shape of ``x_fp8`` where every block has
        been multiplied by its scale (the product is formed in float32).
    """
    rows = x_fp8.size(0)
    blocks_fp32 = x_fp8.to(torch.float32).view(rows, -1, 128)
    block_scales = x_scales.view(rows, -1, 1)
    restored = blocks_fp32 * block_scales
    return restored.view(x_fp8.shape).to(torch.bfloat16)
def inplace_unique(x: torch.Tensor, num_slots: int):
    """Replace each row of ``x`` with its distinct non-negative values, in place.

    After the call every row holds the distinct values in ``[0, num_slots)``
    that occurred in it, sorted in descending order and padded on the right
    with ``-1``. Negative entries in the input are ignored.

    Args:
        x: 2-D integer tensor; modified in place.
        num_slots: Exclusive upper bound of the values being counted.

    Raises:
        AssertionError: If ``x`` is not 2-D.
    """
    assert x.dim() == 2

    # Route negative entries into an extra trailing bin that is discarded below.
    negatives = x < 0
    routed = x.masked_fill(negatives, num_slots)

    histogram = torch.zeros(
        (x.size(0), num_slots + 1), dtype=x.dtype, device=x.device
    )
    histogram.scatter_add_(1, routed, torch.ones_like(routed))
    histogram = histogram[:, :num_slots]

    # Bring occupied bins to the front, then blank out indices of empty bins.
    counts_desc, slot_ids = torch.sort(histogram, dim=-1, descending=True)
    slot_ids.masked_fill_(counts_desc == 0, -1)
    # Re-sort so the surviving ids are in descending order with -1 last.
    slot_ids = torch.sort(slot_ids, descending=True, dim=-1).values

    x.fill_(-1)
    keep = min(num_slots, x.size(1))
    x[:, :keep] = slot_ids[:, :keep]
def create_grouped_scores(
    scores: torch.Tensor, group_idx: torch.Tensor, num_groups: int
):
    """Zero out the scores of every group that is not selected for a token.

    Args:
        scores: ``(num_tokens, num_experts)`` tensor; the expert axis is split
            into ``num_groups`` equally sized contiguous groups.
        group_idx: ``(num_tokens, k)`` indices of the groups to keep per token.
        num_groups: Number of groups along the expert axis.

    Returns:
        A ``(num_tokens, num_experts)`` tensor equal to ``scores`` inside the
        selected groups and zero elsewhere.
    """
    num_tokens, num_experts = scores.shape
    grouped = scores.view(num_tokens, num_groups, -1)

    selected = torch.zeros(
        (num_tokens, num_groups), dtype=torch.bool, device=grouped.device
    )
    selected = selected.scatter_(1, group_idx, True)
    # Broadcast the per-group flag over every expert in that group.
    keep = selected.unsqueeze(-1).expand_as(grouped)

    masked = grouped * keep
    return masked.view(num_tokens, num_experts)
def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None):
    """Time ``fn`` on the current CUDA stream using CUDA events.

    Args:
        fn: Zero-argument callable to measure.
        num_warmups: Untimed calls made before measuring.
        num_tests: Timed calls; the first timed sample is discarded.
        post_fn: Optional zero-argument callable run after each timed call,
            outside the timed region.

    Returns:
        ``(average, minimum, maximum)`` of the kept samples, in seconds
        (event times are divided by 1e3).
    """
    torch.cuda.synchronize()
    # 256 MB scratch buffer used to flush the L2 cache.
    l2_scrubber = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")

    # Warmup
    for _ in range(num_warmups):
        fn()

    # Flush L2
    l2_scrubber.zero_()

    def _timing_events():
        return [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]

    begins = _timing_events()
    ends = _timing_events()
    for begin, end in zip(begins, ends):
        begin.record()
        fn()
        end.record()
        if post_fn is not None:
            post_fn()
    torch.cuda.synchronize()

    seconds = np.array(
        [begin.elapsed_time(end) / 1e3 for begin, end in zip(begins, ends)]
    )
    # Drop the first measurement.
    seconds = seconds[1:]
    return np.average(seconds), np.min(seconds), np.max(seconds)
class empty_suppress:
    """No-op context manager with the same shape as ``suppress_stdout_stderr``."""

    def __enter__(self):
        """Enter without changing anything and hand back the manager itself."""
        return self

    def __exit__(self, *_exc_info):
        """Leave without cleanup; exceptions are never swallowed."""
        return None
class suppress_stdout_stderr:
    """Context manager that silences stdout and stderr at the file-descriptor level.

    On entry both descriptors are pointed at ``os.devnull`` and ``sys.stdout`` /
    ``sys.stderr`` are swapped for the devnull file objects; on exit everything
    is restored and the temporary descriptors and files are closed.
    """

    def __enter__(self):
        """Redirect stdout/stderr to devnull and return the manager."""
        self.outnull_file = open(os.devnull, "w")
        self.errnull_file = open(os.devnull, "w")

        # Descriptor numbers currently backing the Python-level streams.
        stdout_fd = sys.stdout.fileno()
        stderr_fd = sys.stderr.fileno()
        self.old_stdout_fileno_undup = stdout_fd
        self.old_stderr_fileno_undup = stderr_fd

        # Duplicates keep the original targets alive so __exit__ can restore them.
        self.old_stdout_fileno = os.dup(sys.stdout.fileno())
        self.old_stderr_fileno = os.dup(sys.stderr.fileno())

        self.old_stdout, self.old_stderr = sys.stdout, sys.stderr

        os.dup2(self.outnull_file.fileno(), stdout_fd)
        os.dup2(self.errnull_file.fileno(), stderr_fd)

        sys.stdout, sys.stderr = self.outnull_file, self.errnull_file
        return self

    def __exit__(self, *_exc_info):
        """Restore the original streams and release the temporary resources."""
        sys.stdout, sys.stderr = self.old_stdout, self.old_stderr

        # Point the original descriptor numbers back at their saved targets.
        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

        for saved_fd in (self.old_stdout_fileno, self.old_stderr_fileno):
            os.close(saved_fd)

        for null_file in (self.outnull_file, self.errnull_file):
            null_file.close()
def bench_kineto(
    fn,
    kernel_names,
    num_tests: int = 30,
    suppress_kineto_output: bool = False,
    trace_path: Optional[str] = None,
    barrier_comm_profiling: bool = False,
):
    """Profile ``fn`` with the torch profiler and return per-kernel times in seconds.

    ``fn`` is called ``num_tests`` times in each of two profiler steps (one
    warmup step, one active step). The resulting table is then searched for
    each requested kernel name.

    Args:
        fn: Zero-argument callable that launches the kernels of interest.
        kernel_names: A substring, or a tuple of substrings, each of which must
            match exactly one row of the profiler table.
        num_tests: Calls to ``fn`` per profiler step.
        suppress_kineto_output: When true, stdout/stderr are silenced while
            profiling.
        trace_path: If given, a Chrome trace is exported to this path.
        barrier_comm_profiling: When true, a large matmul and an all-reduce are
            issued before each batch of calls.

    Returns:
        A float when ``kernel_names`` is a string, otherwise a tuple of floats
        in the order of ``kernel_names``.

    Raises:
        AssertionError: If ``kernel_names`` has the wrong type or a name does
            not match exactly one table row.
        ValueError: If a matched row's time column carries no recognised unit
            suffix (``ms``, ``us`` or ``s``).
    """
    # Profile
    suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
    with suppress():
        schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule
        ) as prof:
            for _ in range(2):
                # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
                if barrier_comm_profiling:
                    lhs = torch.randn((8192, 8192), dtype=torch.float, device="cuda")
                    rhs = torch.randn((8192, 8192), dtype=torch.float, device="cuda")
                    lhs @ rhs
                    dist.all_reduce(torch.ones(1, dtype=torch.float, device="cuda"))
                for _ in range(num_tests):
                    fn()
                prof.step()

    # Parse the profiling table
    assert isinstance(kernel_names, (str, tuple))
    is_tupled = isinstance(kernel_names, tuple)
    prof_lines = (
        prof.key_averages()
        .table(sort_by="cuda_time_total", max_name_column_width=100)
        .split("\n")
    )
    kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names
    assert all(isinstance(name, str) for name in kernel_names)
    for name in kernel_names:
        assert (
            sum(name in line for line in prof_lines) == 1
        ), f"Errors of the kernel {name} in the profiling table"

    # Save chrome traces
    if trace_path is not None:
        prof.export_chrome_trace(trace_path)

    # Return average kernel times.
    # Suffix -> divisor that converts the printed value to seconds. "s" must be
    # tested last because "ms" and "us" also end in "s".
    units = (("ms", 1e3), ("us", 1e6), ("s", 1.0))
    kernel_times = []
    for name in kernel_names:
        for line in prof_lines:
            if name not in line:
                continue
            # Second-to-last column of the row (presumably the average CUDA
            # time; confirm against the profiler table layout).
            time_str = line.split()[-2]
            for unit, scale in units:
                if time_str.endswith(unit):
                    kernel_times.append(float(time_str[: -len(unit)]) / scale)
                    break
            else:
                # Previously an unknown unit was skipped silently, which
                # shifted or truncated the returned values.
                raise ValueError(
                    f"Cannot parse time {time_str!r} for kernel {name} "
                    "in the profiling table"
                )
            break
    return tuple(kernel_times) if is_tupled else kernel_times[0]
def hash_tensor(t: torch.Tensor):
    """Return the sum of ``t``'s storage reinterpreted as int64 words.

    The tensor's bytes are viewed as ``torch.int64`` (no value conversion) and
    summed; the result is returned as a Python int.
    """
    as_words = t.view(torch.int64)
    checksum = as_words.sum()
    return checksum.item()
benchmark/kernels/deepep/tuning_deepep.py
0 → 100644
View file @
6df81e8a
# MODIFIED FROM https://github.com/deepseek-ai/DeepEP/blob/main/tests/test_internode.py
"""
Example usage:
python tuning_deepep.py --nnodes 4 --node-rank $MY_NODE_RANK --master-addr 1.2.3.4
Then check `deepep_tuned.json`
"""
import
argparse
import
json
import
time
from
copy
import
deepcopy
from
pathlib
import
Path
# noinspection PyUnresolvedReferences
import
deep_ep
import
torch
import
torch.distributed
as
dist
from
deepep_utils
import
(
bench
,
calc_diff
,
create_grouped_scores
,
init_dist
,
inplace_unique
,
per_token_cast_back
,
per_token_cast_to_fp8
,
)
def test_main(
    num_sms: int,
    local_rank: int,
    num_local_ranks: int,
    num_ranks: int,
    num_nodes: int,
    rank: int,
    buffer: deep_ep.Buffer,
    group: dist.ProcessGroup,
    args,
):
    """Check DeepEP dispatch/combine for correctness, then tune their chunk sizes.

    The function runs in three phases:

    1. Build synthetic routing data and compare a locally computed dispatch
       layout against ``buffer.get_dispatch_layout``.
    2. Run ``buffer.dispatch`` / ``buffer.combine`` over every combination of
       previous-event, async, input kind and top-k usage, asserting on the
       results.
    3. Sweep NVL/RDMA chunk sizes for dispatch and combine with ``bench``,
       keep the fastest ``deep_ep.Config`` keyword sets in ``output_data`` and
       have global rank 0 write them via ``_write_output``.

    Args:
        num_sms: Value passed as ``num_sms`` to every ``deep_ep.Config``.
        local_rank: Rank within the node; only local rank 0 prints.
        num_local_ranks: Ranks per node; must be 8.
        num_ranks: Total number of ranks.
        num_nodes: Number of nodes; also used as the number of expert groups.
        rank: Global rank of this process.
        buffer: DeepEP buffer used for layout, dispatch, combine and capture.
        group: Process group used for the collectives and the barrier.
        args: Forwarded unchanged to ``_write_output``.

    Raises:
        AssertionError: On an unsupported rank layout or any failed check.
    """
    # Settings
    num_tokens, hidden, num_topk_groups, num_topk, num_experts = (
        4096,
        7168,
        min(num_nodes, 4),
        8,
        (256 // num_ranks) * num_ranks,
    )
    assert num_experts % num_ranks == 0 and num_local_ranks == 8
    if local_rank == 0:
        print(
            f"[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}",
            flush=True,
        )

    # Random data
    # `x` is constant per rank (every element equals `rank`) so received rows
    # can be attributed to their sender in `check_data`.
    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * rank
    x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
    # Tuple of (fp8 values, scales); later code detects FP8 via isinstance(..., tuple).
    x_e4m3 = per_token_cast_to_fp8(x)
    scores = (
        torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs()
        + 1
    )
    # Pick the best `num_topk_groups` groups per token, then the top-k experts
    # among the experts of those groups only.
    group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1)
    group_idx = torch.topk(
        group_scores, k=num_topk_groups, dim=-1, sorted=False
    ).indices
    masked_scores = create_grouped_scores(scores, group_idx, num_nodes)
    topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[
        1
    ]
    topk_weights = (
        torch.ones((num_tokens, num_topk), dtype=torch.float32, device="cuda") * rank
    )
    topk_weights_pure_rand = torch.randn(
        (num_tokens, num_topk), dtype=torch.float32, device="cuda"
    )
    # Distinct destination ranks per token (padded with -1).
    rank_idx = topk_idx // (num_experts // num_ranks)
    rank_idx.masked_fill_(topk_idx == -1, -1)
    inplace_unique(rank_idx, num_ranks)
    # Distinct destination nodes per token, derived from the destination ranks.
    rdma_rank_idx = rank_idx // num_local_ranks
    rdma_rank_idx.masked_fill_(rank_idx == -1, -1)
    inplace_unique(rdma_rank_idx, num_nodes)

    # RDMA dispatch counts
    rdma_idx = topk_idx // (num_experts // num_nodes)
    rdma_idx.masked_fill_(topk_idx == -1, -1)
    inplace_unique(rdma_idx, num_nodes)
    num_rdma_token_sent = rdma_idx.ne(-1).sum().item()

    # Expert meta
    num_tokens_per_expert = torch.zeros((num_experts,), dtype=torch.int, device="cuda")
    for i in range(num_experts):
        num_tokens_per_expert[i] = (topk_idx == i).sum()
    gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
    dist.all_reduce(gbl_num_tokens_per_expert, group=group)

    # Rank layout meta
    num_tokens_per_rank = torch.empty((num_ranks,), dtype=torch.int, device="cuda")
    num_tokens_per_rdma_rank = torch.empty((num_nodes,), dtype=torch.int, device="cuda")
    token_idx_in_rank = torch.full(
        (num_ranks, num_tokens), -1, dtype=torch.long, device="cuda"
    )
    for i in range(num_ranks):
        num_tokens_per_rank[i] = (rank_idx == i).sum()
        # Boolean per token: does it go to rank i?
        token_sel = (rank_idx == i).max(dim=-1)[0]
        count = token_sel.sum().item()
        # Selected token ids first, then put those ids in ascending order and
        # number them 0..count-1.
        tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]
        tokens[:count] = torch.sort(tokens[:count])[0]
        token_idx_in_rank[i][tokens[:count]] = torch.arange(
            count, dtype=torch.long, device="cuda"
        )
    for i in range(num_nodes):
        num_tokens_per_rdma_rank[i] = (rdma_rank_idx == i).sum()
    token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)
    is_token_in_rank = token_idx_in_rank >= 0
    gbl_num_tokens_per_rank = num_tokens_per_rank.clone()
    dist.all_reduce(gbl_num_tokens_per_rank, group=group)

    # Compare the locally computed layout with the one DeepEP reports.
    (
        ref_num_tokens_per_rank,
        ref_num_tokens_per_rdma_rank,
        ref_num_tokens_per_expert,
        ref_is_token_in_rank,
        _,
    ) = buffer.get_dispatch_layout(topk_idx, num_experts)
    assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank)
    assert torch.allclose(ref_num_tokens_per_rdma_rank, num_tokens_per_rdma_rank)
    assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert)
    assert torch.allclose(ref_is_token_in_rank, is_token_in_rank)
    t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0]
    if local_rank == 0:
        print(f"[layout] Kernel performance: {t * 1000:.3f} ms", flush=True)
        print("", flush=True)
    group.barrier()
    time.sleep(1)

    # Config
    rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (144, 160) else 512)
    config = deep_ep.Config(num_sms, 8, nvl_buffer_size, 16, rdma_buffer_size)

    # Test dispatch
    # noinspection PyShadowingNames
    def check_data(check_x, recv_gbl_rank_prefix_sum):
        # Every row must be constant, and rows in the i-th prefix-sum segment
        # must all equal i (the sender's rank, given how `x` was built).
        assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1))
        check_start = 0
        for i in range(num_ranks):
            check_end = recv_gbl_rank_prefix_sum[i].item()
            assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0
            check_start = check_end

    for previous_mode in (False, True):
        for async_mode in (False, True):
            for current_x in (x_pure_rand, x, x_e4m3):
                for with_topk in (False, True):
                    if local_rank == 0:
                        print(
                            f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...',
                            flush=True,
                            end="",
                        )
                    dispatch_args = {
                        "x": current_x,
                        "num_tokens_per_rank": num_tokens_per_rank,
                        "num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
                        "is_token_in_rank": is_token_in_rank,
                        "num_tokens_per_expert": num_tokens_per_expert,
                        "config": config,
                        "async_finish": async_mode,
                    }
                    if with_topk:
                        dispatch_args.update(
                            {
                                "topk_idx": topk_idx,
                                "topk_weights": (
                                    topk_weights_pure_rand
                                    if current_x is x_pure_rand
                                    else topk_weights
                                ),
                            }
                        )
                    if previous_mode:
                        dispatch_args.update({"previous_event": buffer.capture()})
                    (
                        recv_x,
                        recv_topk_idx,
                        recv_topk_weights,
                        recv_num_tokens_per_expert_list,
                        handle,
                        event,
                    ) = buffer.dispatch(**dispatch_args)
                    event.current_stream_wait() if async_mode else ()
                    # FP8 results come back as a tuple; convert them to BF16
                    # so the same checks apply.
                    recv_x = (
                        per_token_cast_back(*recv_x)
                        if isinstance(recv_x, tuple)
                        else recv_x
                    )

                    # Checks
                    # NOTE(review): the position of the prefix sum inside
                    # `handle` is defined by deep_ep; confirm the -4 index
                    # when upgrading the library.
                    recv_gbl_rank_prefix_sum = handle[-4]
                    assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(
                        0
                    ), f"{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}"
                    assert (
                        gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist()
                        == recv_num_tokens_per_expert_list
                    )
                    if current_x is not x_pure_rand:
                        check_data(recv_x, recv_gbl_rank_prefix_sum)
                    if with_topk:
                        # Check `topk_idx`: every entry is -1 or a valid local
                        # expert index, and per-expert counts match.
                        assert (
                            recv_topk_idx.eq(-1)
                            | (
                                (recv_topk_idx >= 0)
                                & (recv_topk_idx < (num_experts // num_ranks))
                            )
                        ).sum().item() == recv_topk_idx.numel()
                        for i, count in enumerate(recv_num_tokens_per_expert_list):
                            assert recv_topk_idx.eq(i).sum().item() == count

                        # Check `topk_weights`
                        if current_x is not x_pure_rand:
                            # Overwrite the weights of unused slots with the
                            # row maximum so each row is constant for check_data.
                            recv_topk_weights[recv_topk_idx.eq(-1)] = (
                                recv_topk_weights.amax(dim=1, keepdim=True).expand_as(
                                    recv_topk_weights
                                )[recv_topk_idx.eq(-1)]
                            )
                            check_data(recv_topk_weights, recv_gbl_rank_prefix_sum)

                    # Test cached dispatch (must be run without top-k inputs)
                    if not with_topk:
                        dispatch_args = {
                            "x": current_x,
                            "handle": handle,
                            "config": config,
                            "async_finish": async_mode,
                        }
                        if previous_mode:
                            dispatch_args.update({"previous_event": buffer.capture()})
                        recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
                        event.current_stream_wait() if async_mode else ()
                        recv_x = (
                            per_token_cast_back(*recv_x)
                            if isinstance(recv_x, tuple)
                            else recv_x
                        )
                        if current_x is not x_pure_rand:
                            check_data(recv_x, recv_gbl_rank_prefix_sum)

                    # Test combine
                    combine_args = {
                        "x": recv_x,
                        "handle": handle,
                        "config": config,
                        "async_finish": async_mode,
                    }
                    if with_topk:
                        combine_args.update({"topk_weights": recv_topk_weights})
                    if previous_mode:
                        combine_args.update({"previous_event": buffer.capture()})
                    combined_x, combined_topk_weights, event = buffer.combine(
                        **combine_args
                    )
                    event.current_stream_wait() if async_mode else ()
                    # Divide by the number of ranks each token was sent to
                    # before comparing with the original input.
                    check_x = combined_x.float() / is_token_in_rank.sum(
                        dim=1
                    ).unsqueeze(1)
                    ref_x = x_pure_rand if current_x is x_pure_rand else x
                    assert calc_diff(check_x, ref_x) < 5e-6
                    if with_topk:
                        check_topk_weights = (
                            combined_topk_weights
                            if (current_x is x_pure_rand)
                            else (
                                combined_topk_weights
                                / is_token_in_rank.sum(dim=1).unsqueeze(1)
                            )
                        )
                        ref_topk_weights = (
                            topk_weights_pure_rand
                            if current_x is x_pure_rand
                            else topk_weights
                        )
                        assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9

                    # For later tuning (2 bytes per BF16 element)
                    dispatch_bf16_rdma_send_bytes = num_rdma_token_sent * hidden * 2
                    dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2
                    combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes
                    combine_bf16_rdma_recv_bytes = dispatch_bf16_rdma_send_bytes

                    if local_rank == 0:
                        print(" passed", flush=True)
    if local_rank == 0:
        print("", flush=True)

    # Collects the best config kwargs; written out by global rank 0 at the end.
    output_data = {}

    # Tune dispatch performance
    best_dispatch_results = None
    # Byte ratio of FP8 to BF16 traffic (presumably 1 byte per value plus a
    # 4-byte scale per 128 values, over 2 bytes per value; confirm).
    fp8_factor = (1 + 4 / 128) / 2
    for current_x in (x_e4m3, x):
        best_time, best_results = 1e10, None
        rdma_send_bytes = (
            (dispatch_bf16_rdma_send_bytes * fp8_factor)
            if isinstance(current_x, tuple)
            else dispatch_bf16_rdma_send_bytes
        )
        nvl_recv_bytes = (
            (dispatch_bf16_nvl_recv_bytes * fp8_factor)
            if isinstance(current_x, tuple)
            else dispatch_bf16_nvl_recv_bytes
        )
        for nvl_chunk_size in range(4, 33, 4):
            for rdma_chunk_size in range(4, 33, 4):
                config_kwargs = {
                    "num_sms": num_sms,
                    "num_max_nvl_chunked_send_tokens": nvl_chunk_size,
                    "num_max_nvl_chunked_recv_tokens": nvl_buffer_size,
                    "num_max_rdma_chunked_send_tokens": rdma_chunk_size,
                    "num_max_rdma_chunked_recv_tokens": rdma_buffer_size,
                }
                config = deep_ep.Config(**config_kwargs)
                tune_args = {"x": current_x, "handle": handle, "config": config}
                t = bench(lambda: buffer.dispatch(**tune_args))[0]
                if t < best_time:
                    best_time, best_results = t, (
                        num_sms,
                        nvl_chunk_size,
                        rdma_chunk_size,
                        config_kwargs,
                    )
                if local_rank == 0:
                    print(
                        f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ",
                        flush=True,
                    )
        if local_rank == 0:
            print(
                f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)',
                flush=True,
            )
            print("", flush=True)
            # Only the FP8 winner is recorded as the dispatch config.
            is_fp8 = isinstance(current_x, tuple)
            if is_fp8:
                output_data["normal_dispatch"] = deepcopy(best_results[3])

        if isinstance(current_x, tuple):
            # Gather every rank's best FP8 config and adopt the first entry
            # of the gathered list so all ranks agree on one config.
            best_dispatch_results = torch.tensor(
                [best_results[0], best_results[1], best_results[2]],
                dtype=torch.int32,
                device="cuda",
            )
            all_best_fp8_results_list = [
                torch.zeros_like(best_dispatch_results)
                for _ in range(torch.distributed.get_world_size())
            ]
            dist.all_gather(
                all_best_fp8_results_list, best_dispatch_results, group=group
            )
            best_dispatch_results = all_best_fp8_results_list[0].tolist()
    dispatch_config = deep_ep.Config(
        best_dispatch_results[0],
        best_dispatch_results[1],
        nvl_buffer_size,
        best_dispatch_results[2],
        rdma_buffer_size,
    )

    # Re-dispatch the BF16 input with the agreed config to obtain a fresh
    # `recv_x` / `handle` pair for the combine sweep.
    dispatch_args = {
        "x": x,
        "num_tokens_per_rank": num_tokens_per_rank,
        "num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
        "is_token_in_rank": is_token_in_rank,
        "num_tokens_per_expert": num_tokens_per_expert,
        "config": dispatch_config if dispatch_config is not None else config,
    }
    recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)

    # Tune combine performance
    best_time, best_results = 1e10, None
    for nvl_chunk_size in range(1, 5, 1):
        for rdma_chunk_size in range(8, 33, 4):
            config_kwargs = {
                "num_sms": num_sms,
                "num_max_nvl_chunked_send_tokens": nvl_chunk_size,
                "num_max_nvl_chunked_recv_tokens": nvl_buffer_size,
                "num_max_rdma_chunked_send_tokens": rdma_chunk_size,
                "num_max_rdma_chunked_recv_tokens": rdma_buffer_size,
            }
            config = deep_ep.Config(**config_kwargs)
            tune_args = {"x": recv_x, "handle": handle, "config": config}
            t = bench(lambda: buffer.combine(**tune_args))[0]
            if local_rank == 0:
                print(
                    f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ",
                    flush=True,
                )
                # The best combine result is tracked on local rank 0 only.
                if t < best_time:
                    best_time, best_results = t, (
                        num_sms,
                        nvl_chunk_size,
                        rdma_chunk_size,
                        config_kwargs,
                    )

    if local_rank == 0:
        print(
            f"[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)",
            flush=True,
        )
        print("", flush=True)
        output_data["normal_combine"] = deepcopy(best_results[3])

    # Only one process in the whole job writes the result file.
    if rank == 0 and local_rank == 0:
        _write_output(args, output_data)
def
_write_output
(
args
,
output_data
):
text
=
json
.
dumps
(
output_data
,
indent
=
4
)
output_path
=
args
.
output_path
print
(
f
"Write to
{
output_path
}
with
{
text
}
"
)
Path
(
output_path
).
write_text
(
text
)
# noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int, args):
    """Per-process entry point: set up distributed state and run ``test_main``.

    Args:
        local_rank: Index of this process on its node.
        num_local_ranks: Number of processes per node; must be 8.
        args: Parsed command-line namespace; ``nnodes`` and ``num_sms`` are
            read here and the whole object is passed on to ``init_dist`` and
            ``test_main``.

    Raises:
        AssertionError: Unless there are 8 local ranks and more than 8 ranks
            in total.
    """
    node_count = args.nnodes
    rank, num_ranks, group = init_dist(local_rank, num_local_ranks, args)

    sm_count = args.num_sms
    # Half as many queue pairs per rank as SMs.
    qps_per_rank = sm_count // 2

    buffer = deep_ep.Buffer(
        group,
        int(1e9),
        int(1e9),
        low_latency_mode=False,
        num_qps_per_rank=qps_per_rank,
    )
    assert num_local_ranks == 8
    assert num_ranks > 8
    # Seed with the global rank so each rank draws different random data.
    torch.manual_seed(rank)

    test_main(
        sm_count,
        local_rank,
        num_local_ranks,
        num_ranks,
        node_count,
        rank,
        buffer,
        group,
        args,
    )
    if local_rank == 0:
        print("", flush=True)
if __name__ == "__main__":
    # (flag, value type, default) for every supported command-line option.
    cli_options = (
        ("--num-sms", int, 24),
        ("--output-path", str, "deepep_tuned.json"),
        ("--nnodes", int, 1),
        ("--node-rank", int, 0),
        ("--master-addr", str, "127.0.0.1"),
        ("--master-port", int, 8361),
    )
    parser = argparse.ArgumentParser()
    for flag, value_type, fallback in cli_options:
        parser.add_argument(flag, type=value_type, default=fallback)
    args = parser.parse_args()
    print(f"Start system with {args=}")

    # One worker process per local rank; each receives (index, num_processes, args).
    num_processes = 8
    torch.multiprocessing.spawn(
        test_loop,
        args=(num_processes, args),
        nprocs=num_processes,
    )
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment