sglang / Commits / 6df81e8a

Commit 6df81e8a (unverified)
Authored May 29, 2025 by fzyzcjy; committed by GitHub on May 29, 2025

Support tuning DeepEP configs (#6742)

Parent: 3ab7d9b5

Showing 2 changed files with 694 additions and 0 deletions (+694 −0):

benchmark/kernels/deepep/deepep_utils.py   +218 −0
benchmark/kernels/deepep/tuning_deepep.py  +476 −0
benchmark/kernels/deepep/deepep_utils.py (new file, mode 0 → 100644)
# ADAPTED FROM https://github.com/deepseek-ai/DeepEP/blob/main/tests/utils.py
import os
import sys
from typing import Optional

import numpy as np
import torch
import torch.distributed as dist


def init_dist(local_rank: int, num_local_ranks: int, args):
    ip = args.master_addr
    port = args.master_port
    num_nodes = args.nnodes
    node_rank = args.node_rank

    assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8

    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://{ip}:{port}",
        world_size=num_nodes * num_local_ranks,
        rank=node_rank * num_local_ranks + local_rank,
    )
    torch.set_default_dtype(torch.bfloat16)
    torch.set_default_device("cuda")
    torch.cuda.set_device(local_rank)

    return (
        dist.get_rank(),
        dist.get_world_size(),
        dist.new_group(list(range(num_local_ranks * num_nodes))),
    )


def calc_diff(x: torch.Tensor, y: torch.Tensor):
    x, y = x.double() + 1, y.double() + 1
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return (1 - sim).item()


def per_token_cast_to_fp8(x: torch.Tensor):
    assert x.dim() == 2 and x.size(1) % 128 == 0
    m, n = x.shape
    x_view = x.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(
        m, n
    ), (x_amax / 448.0).view(m, -1)


def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):
    x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128)
    x_scales = x_scales.view(x_fp8.size(0), -1, 1)
    return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16)


def inplace_unique(x: torch.Tensor, num_slots: int):
    assert x.dim() == 2
    mask = x < 0
    x_padded = x.masked_fill(mask, num_slots)
    bin_count = torch.zeros((x.size(0), num_slots + 1), dtype=x.dtype, device=x.device)
    bin_count.scatter_add_(1, x_padded, torch.ones_like(x_padded))
    bin_count = bin_count[:, :num_slots]
    sorted_bin_count, sorted_bin_idx = torch.sort(bin_count, dim=-1, descending=True)
    sorted_bin_idx.masked_fill_(sorted_bin_count == 0, -1)
    sorted_bin_idx = torch.sort(sorted_bin_idx, descending=True, dim=-1).values
    x[:, :].fill_(-1)
    valid_len = min(num_slots, x.size(1))
    x[:, :valid_len] = sorted_bin_idx[:, :valid_len]


def create_grouped_scores(
    scores: torch.Tensor, group_idx: torch.Tensor, num_groups: int
):
    num_tokens, num_experts = scores.shape
    scores = scores.view(num_tokens, num_groups, -1)
    mask = torch.zeros((num_tokens, num_groups), dtype=torch.bool, device=scores.device)
    mask = mask.scatter_(1, group_idx, True).unsqueeze(-1).expand_as(scores)
    return (scores * mask).view(num_tokens, num_experts)


def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None):
    # Flush L2 cache with 256 MB data
    torch.cuda.synchronize()
    cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")

    # Warmup
    for _ in range(num_warmups):
        fn()

    # Flush L2
    cache.zero_()

    # Testing
    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
    for i in range(num_tests):
        # Record
        start_events[i].record()
        fn()
        end_events[i].record()
        if post_fn is not None:
            post_fn()
    torch.cuda.synchronize()

    times = np.array(
        [s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)]
    )[1:]
    return np.average(times), np.min(times), np.max(times)


class empty_suppress:
    def __enter__(self):
        return self

    def __exit__(self, *_):
        pass


class suppress_stdout_stderr:
    def __enter__(self):
        self.outnull_file = open(os.devnull, "w")
        self.errnull_file = open(os.devnull, "w")

        self.old_stdout_fileno_undup = sys.stdout.fileno()
        self.old_stderr_fileno_undup = sys.stderr.fileno()

        self.old_stdout_fileno = os.dup(sys.stdout.fileno())
        self.old_stderr_fileno = os.dup(sys.stderr.fileno())

        self.old_stdout = sys.stdout
        self.old_stderr = sys.stderr

        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)

        sys.stdout = self.outnull_file
        sys.stderr = self.errnull_file
        return self

    def __exit__(self, *_):
        sys.stdout = self.old_stdout
        sys.stderr = self.old_stderr

        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

        os.close(self.old_stdout_fileno)
        os.close(self.old_stderr_fileno)

        self.outnull_file.close()
        self.errnull_file.close()


def bench_kineto(
    fn,
    kernel_names,
    num_tests: int = 30,
    suppress_kineto_output: bool = False,
    trace_path: Optional[str] = None,
    barrier_comm_profiling: bool = False,
):
    # Profile
    suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
    with suppress():
        schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule
        ) as prof:
            for i in range(2):
                # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
                if barrier_comm_profiling:
                    lhs = torch.randn((8192, 8192), dtype=torch.float, device="cuda")
                    rhs = torch.randn((8192, 8192), dtype=torch.float, device="cuda")
                    lhs @ rhs
                    dist.all_reduce(torch.ones(1, dtype=torch.float, device="cuda"))
                for _ in range(num_tests):
                    fn()
                prof.step()

    # Parse the profiling table
    assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
    is_tupled = isinstance(kernel_names, tuple)
    prof_lines = (
        prof.key_averages()
        .table(sort_by="cuda_time_total", max_name_column_width=100)
        .split("\n")
    )
    kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names
    assert all([isinstance(name, str) for name in kernel_names])
    for name in kernel_names:
        assert (
            sum([name in line for line in prof_lines]) == 1
        ), f"Errors of the kernel {name} in the profiling table"

    # Save chrome traces
    if trace_path is not None:
        prof.export_chrome_trace(trace_path)

    # Return average kernel times
    units = {"ms": 1e3, "us": 1e6}
    kernel_times = []
    for name in kernel_names:
        for line in prof_lines:
            if name in line:
                time_str = line.split()[-2]
                for unit, scale in units.items():
                    if unit in time_str:
                        kernel_times.append(float(time_str.replace(unit, "")) / scale)
                        break
                break

    return tuple(kernel_times) if is_tupled else kernel_times[0]


def hash_tensor(t: torch.Tensor):
    return t.view(torch.int64).sum().item()
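For reference, the helpers above compose as follows. This is a minimal usage sketch, not part of the commit, assuming a CUDA device whose PyTorch build supports torch.float8_e4m3fn; the shapes are illustrative only.

# Minimal sketch (not part of the commit): round-trip a BF16 tensor through the
# per-token FP8 helpers and time a kernel with `bench`.
import torch
from deepep_utils import bench, calc_diff, per_token_cast_back, per_token_cast_to_fp8

x = torch.randn((4096, 7168), dtype=torch.bfloat16, device="cuda")
x_fp8, x_scales = per_token_cast_to_fp8(x)    # FP8 payload + one scale per 128-column block
x_rt = per_token_cast_back(x_fp8, x_scales)   # back to BF16
print(f"relative difference: {calc_diff(x_rt, x):.2e}")

# `bench` warms up, flushes L2, drops the first timed sample, and returns
# (average, min, max) in seconds.
avg_t, min_t, max_t = bench(lambda: x @ x.T)
print(f"avg {avg_t * 1e3:.3f} ms, min {min_t * 1e3:.3f} ms, max {max_t * 1e3:.3f} ms")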
benchmark/kernels/deepep/tuning_deepep.py (new file, mode 0 → 100644)
# MODIFIED FROM https://github.com/deepseek-ai/DeepEP/blob/main/tests/test_internode.py
"""
Example usage:
python tuning_deepep.py --nnodes 4 --node-rank $MY_NODE_RANK --master-addr 1.2.3.4
Then check `deepep_tuned.json`
"""
import
argparse
import
json
import
time
from
copy
import
deepcopy
from
pathlib
import
Path
# noinspection PyUnresolvedReferences
import
deep_ep
import
torch
import
torch.distributed
as
dist
from
deepep_utils
import
(
bench
,
calc_diff
,
create_grouped_scores
,
init_dist
,
inplace_unique
,
per_token_cast_back
,
per_token_cast_to_fp8
,
)
def
test_main
(
num_sms
:
int
,
local_rank
:
int
,
num_local_ranks
:
int
,
num_ranks
:
int
,
num_nodes
:
int
,
rank
:
int
,
buffer
:
deep_ep
.
Buffer
,
group
:
dist
.
ProcessGroup
,
args
,
):
# Settings
num_tokens
,
hidden
,
num_topk_groups
,
num_topk
,
num_experts
=
(
4096
,
7168
,
min
(
num_nodes
,
4
),
8
,
(
256
//
num_ranks
)
*
num_ranks
,
)
assert
num_experts
%
num_ranks
==
0
and
num_local_ranks
==
8
if
local_rank
==
0
:
print
(
f
"[config] num_tokens=
{
num_tokens
}
, hidden=
{
hidden
}
, num_topk_groups=
{
num_topk_groups
}
, num_topk=
{
num_topk
}
"
,
flush
=
True
,
)
# Random data
x
=
torch
.
ones
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
*
rank
x_pure_rand
=
torch
.
randn
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
x_e4m3
=
per_token_cast_to_fp8
(
x
)
scores
=
(
torch
.
randn
((
num_tokens
,
num_experts
),
dtype
=
torch
.
float32
,
device
=
"cuda"
).
abs
()
+
1
)
group_scores
=
scores
.
view
(
num_tokens
,
num_nodes
,
-
1
).
amax
(
dim
=-
1
)
group_idx
=
torch
.
topk
(
group_scores
,
k
=
num_topk_groups
,
dim
=-
1
,
sorted
=
False
).
indices
masked_scores
=
create_grouped_scores
(
scores
,
group_idx
,
num_nodes
)
topk_idx
=
torch
.
topk
(
masked_scores
,
num_topk
,
dim
=-
1
,
largest
=
True
,
sorted
=
False
)[
1
]
topk_weights
=
(
torch
.
ones
((
num_tokens
,
num_topk
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
*
rank
)
topk_weights_pure_rand
=
torch
.
randn
(
(
num_tokens
,
num_topk
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
rank_idx
=
topk_idx
//
(
num_experts
//
num_ranks
)
rank_idx
.
masked_fill_
(
topk_idx
==
-
1
,
-
1
)
inplace_unique
(
rank_idx
,
num_ranks
)
rdma_rank_idx
=
rank_idx
//
num_local_ranks
rdma_rank_idx
.
masked_fill_
(
rank_idx
==
-
1
,
-
1
)
inplace_unique
(
rdma_rank_idx
,
num_nodes
)
# RDMA dispatch counts
rdma_idx
=
topk_idx
//
(
num_experts
//
num_nodes
)
rdma_idx
.
masked_fill_
(
topk_idx
==
-
1
,
-
1
)
inplace_unique
(
rdma_idx
,
num_nodes
)
num_rdma_token_sent
=
rdma_idx
.
ne
(
-
1
).
sum
().
item
()
# Expert meta
num_tokens_per_expert
=
torch
.
zeros
((
num_experts
,),
dtype
=
torch
.
int
,
device
=
"cuda"
)
for
i
in
range
(
num_experts
):
num_tokens_per_expert
[
i
]
=
(
topk_idx
==
i
).
sum
()
gbl_num_tokens_per_expert
=
num_tokens_per_expert
.
clone
()
dist
.
all_reduce
(
gbl_num_tokens_per_expert
,
group
=
group
)
# Rank layout meta
num_tokens_per_rank
=
torch
.
empty
((
num_ranks
,),
dtype
=
torch
.
int
,
device
=
"cuda"
)
num_tokens_per_rdma_rank
=
torch
.
empty
((
num_nodes
,),
dtype
=
torch
.
int
,
device
=
"cuda"
)
token_idx_in_rank
=
torch
.
full
(
(
num_ranks
,
num_tokens
),
-
1
,
dtype
=
torch
.
long
,
device
=
"cuda"
)
for
i
in
range
(
num_ranks
):
num_tokens_per_rank
[
i
]
=
(
rank_idx
==
i
).
sum
()
token_sel
=
(
rank_idx
==
i
).
max
(
dim
=-
1
)[
0
]
count
=
token_sel
.
sum
().
item
()
tokens
=
torch
.
sort
(
token_sel
.
to
(
torch
.
int
),
descending
=
True
)[
1
]
tokens
[:
count
]
=
torch
.
sort
(
tokens
[:
count
])[
0
]
token_idx_in_rank
[
i
][
tokens
[:
count
]]
=
torch
.
arange
(
count
,
dtype
=
torch
.
long
,
device
=
"cuda"
)
for
i
in
range
(
num_nodes
):
num_tokens_per_rdma_rank
[
i
]
=
(
rdma_rank_idx
==
i
).
sum
()
token_idx_in_rank
=
token_idx_in_rank
.
T
.
contiguous
().
to
(
torch
.
int
)
is_token_in_rank
=
token_idx_in_rank
>=
0
gbl_num_tokens_per_rank
=
num_tokens_per_rank
.
clone
()
dist
.
all_reduce
(
gbl_num_tokens_per_rank
,
group
=
group
)
(
ref_num_tokens_per_rank
,
ref_num_tokens_per_rdma_rank
,
ref_num_tokens_per_expert
,
ref_is_token_in_rank
,
_
,
)
=
buffer
.
get_dispatch_layout
(
topk_idx
,
num_experts
)
assert
torch
.
allclose
(
ref_num_tokens_per_rank
,
num_tokens_per_rank
)
assert
torch
.
allclose
(
ref_num_tokens_per_rdma_rank
,
num_tokens_per_rdma_rank
)
assert
torch
.
allclose
(
ref_num_tokens_per_expert
,
num_tokens_per_expert
)
assert
torch
.
allclose
(
ref_is_token_in_rank
,
is_token_in_rank
)
t
=
bench
(
lambda
:
buffer
.
get_dispatch_layout
(
topk_idx
,
num_experts
))[
0
]
if
local_rank
==
0
:
print
(
f
"[layout] Kernel performance:
{
t
*
1000
:.
3
f
}
ms"
,
flush
=
True
)
print
(
""
,
flush
=
True
)
group
.
barrier
()
time
.
sleep
(
1
)
# Config
rdma_buffer_size
,
nvl_buffer_size
=
128
,
(
720
if
num_ranks
in
(
144
,
160
)
else
512
)
config
=
deep_ep
.
Config
(
num_sms
,
8
,
nvl_buffer_size
,
16
,
rdma_buffer_size
)
# Test dispatch
# noinspection PyShadowingNames
def
check_data
(
check_x
,
recv_gbl_rank_prefix_sum
):
assert
torch
.
allclose
(
check_x
.
amin
(
dim
=
1
),
check_x
.
amax
(
dim
=
1
))
check_start
=
0
for
i
in
range
(
num_ranks
):
check_end
=
recv_gbl_rank_prefix_sum
[
i
].
item
()
assert
(
check_x
[
check_start
:
check_end
,
:].
int
()
-
i
).
sum
().
item
()
==
0
check_start
=
check_end
for
previous_mode
in
(
False
,
True
):
for
async_mode
in
(
False
,
True
):
for
current_x
in
(
x_pure_rand
,
x
,
x_e4m3
):
for
with_topk
in
(
False
,
True
):
if
local_rank
==
0
:
print
(
f
'[testing] Running with
{
"FP8"
if
isinstance
(
current_x
,
tuple
)
else
"BF16"
}
,
{
"with"
if
with_topk
else
"without"
}
top-k (async=
{
async_mode
}
, previous=
{
previous_mode
}
) ...'
,
flush
=
True
,
end
=
""
,
)
dispatch_args
=
{
"x"
:
current_x
,
"num_tokens_per_rank"
:
num_tokens_per_rank
,
"num_tokens_per_rdma_rank"
:
num_tokens_per_rdma_rank
,
"is_token_in_rank"
:
is_token_in_rank
,
"num_tokens_per_expert"
:
num_tokens_per_expert
,
"config"
:
config
,
"async_finish"
:
async_mode
,
}
if
with_topk
:
dispatch_args
.
update
(
{
"topk_idx"
:
topk_idx
,
"topk_weights"
:
(
topk_weights_pure_rand
if
current_x
is
x_pure_rand
else
topk_weights
),
}
)
if
previous_mode
:
dispatch_args
.
update
({
"previous_event"
:
buffer
.
capture
()})
(
recv_x
,
recv_topk_idx
,
recv_topk_weights
,
recv_num_tokens_per_expert_list
,
handle
,
event
,
)
=
buffer
.
dispatch
(
**
dispatch_args
)
event
.
current_stream_wait
()
if
async_mode
else
()
recv_x
=
(
per_token_cast_back
(
*
recv_x
)
if
isinstance
(
recv_x
,
tuple
)
else
recv_x
)
# Checks
recv_gbl_rank_prefix_sum
=
handle
[
-
4
]
assert
gbl_num_tokens_per_rank
[
rank
].
item
()
==
recv_x
.
size
(
0
),
f
"
{
gbl_num_tokens_per_rank
[
rank
].
item
()
}
!=
{
recv_x
.
size
(
0
)
}
"
assert
(
gbl_num_tokens_per_expert
.
view
(
num_ranks
,
-
1
)[
rank
].
tolist
()
==
recv_num_tokens_per_expert_list
)
if
current_x
is
not
x_pure_rand
:
check_data
(
recv_x
,
recv_gbl_rank_prefix_sum
)
if
with_topk
:
# Check `topk_idx`
assert
(
recv_topk_idx
.
eq
(
-
1
)
|
(
(
recv_topk_idx
>=
0
)
&
(
recv_topk_idx
<
(
num_experts
//
num_ranks
))
)
).
sum
().
item
()
==
recv_topk_idx
.
numel
()
for
i
,
count
in
enumerate
(
recv_num_tokens_per_expert_list
):
assert
recv_topk_idx
.
eq
(
i
).
sum
().
item
()
==
count
# Check `topk_weights`
if
current_x
is
not
x_pure_rand
:
recv_topk_weights
[
recv_topk_idx
.
eq
(
-
1
)]
=
(
recv_topk_weights
.
amax
(
dim
=
1
,
keepdim
=
True
).
expand_as
(
recv_topk_weights
)[
recv_topk_idx
.
eq
(
-
1
)]
)
check_data
(
recv_topk_weights
,
recv_gbl_rank_prefix_sum
)
# Test cached dispatch (must without top-k staffs)
if
not
with_topk
:
dispatch_args
=
{
"x"
:
current_x
,
"handle"
:
handle
,
"config"
:
config
,
"async_finish"
:
async_mode
,
}
if
previous_mode
:
dispatch_args
.
update
({
"previous_event"
:
buffer
.
capture
()})
recv_x
,
_
,
_
,
_
,
_
,
event
=
buffer
.
dispatch
(
**
dispatch_args
)
event
.
current_stream_wait
()
if
async_mode
else
()
recv_x
=
(
per_token_cast_back
(
*
recv_x
)
if
isinstance
(
recv_x
,
tuple
)
else
recv_x
)
if
current_x
is
not
x_pure_rand
:
check_data
(
recv_x
,
recv_gbl_rank_prefix_sum
)
# Test combine
combine_args
=
{
"x"
:
recv_x
,
"handle"
:
handle
,
"config"
:
config
,
"async_finish"
:
async_mode
,
}
if
with_topk
:
combine_args
.
update
({
"topk_weights"
:
recv_topk_weights
})
if
previous_mode
:
combine_args
.
update
({
"previous_event"
:
buffer
.
capture
()})
combined_x
,
combined_topk_weights
,
event
=
buffer
.
combine
(
**
combine_args
)
event
.
current_stream_wait
()
if
async_mode
else
()
check_x
=
combined_x
.
float
()
/
is_token_in_rank
.
sum
(
dim
=
1
).
unsqueeze
(
1
)
ref_x
=
x_pure_rand
if
current_x
is
x_pure_rand
else
x
assert
calc_diff
(
check_x
,
ref_x
)
<
5e-6
if
with_topk
:
check_topk_weights
=
(
combined_topk_weights
if
(
current_x
is
x_pure_rand
)
else
(
combined_topk_weights
/
is_token_in_rank
.
sum
(
dim
=
1
).
unsqueeze
(
1
)
)
)
ref_topk_weights
=
(
topk_weights_pure_rand
if
current_x
is
x_pure_rand
else
topk_weights
)
assert
calc_diff
(
check_topk_weights
,
ref_topk_weights
)
<
1e-9
# For later tuning
dispatch_bf16_rdma_send_bytes
=
num_rdma_token_sent
*
hidden
*
2
dispatch_bf16_nvl_recv_bytes
=
recv_x
.
numel
()
*
2
combine_bf16_nvl_send_bytes
=
dispatch_bf16_nvl_recv_bytes
combine_bf16_rdma_recv_bytes
=
dispatch_bf16_rdma_send_bytes
if
local_rank
==
0
:
print
(
" passed"
,
flush
=
True
)
if
local_rank
==
0
:
print
(
""
,
flush
=
True
)
output_data
=
{}
# Tune dispatch performance
best_dispatch_results
=
None
fp8_factor
=
(
1
+
4
/
128
)
/
2
for
current_x
in
(
x_e4m3
,
x
):
best_time
,
best_results
=
1e10
,
None
rdma_send_bytes
=
(
(
dispatch_bf16_rdma_send_bytes
*
fp8_factor
)
if
isinstance
(
current_x
,
tuple
)
else
dispatch_bf16_rdma_send_bytes
)
nvl_recv_bytes
=
(
(
dispatch_bf16_nvl_recv_bytes
*
fp8_factor
)
if
isinstance
(
current_x
,
tuple
)
else
dispatch_bf16_nvl_recv_bytes
)
for
nvl_chunk_size
in
range
(
4
,
33
,
4
):
for
rdma_chunk_size
in
range
(
4
,
33
,
4
):
config_kwargs
=
{
"num_sms"
:
num_sms
,
"num_max_nvl_chunked_send_tokens"
:
nvl_chunk_size
,
"num_max_nvl_chunked_recv_tokens"
:
nvl_buffer_size
,
"num_max_rdma_chunked_send_tokens"
:
rdma_chunk_size
,
"num_max_rdma_chunked_recv_tokens"
:
rdma_buffer_size
,
}
config
=
deep_ep
.
Config
(
**
config_kwargs
)
tune_args
=
{
"x"
:
current_x
,
"handle"
:
handle
,
"config"
:
config
}
t
=
bench
(
lambda
:
buffer
.
dispatch
(
**
tune_args
))[
0
]
if
t
<
best_time
:
best_time
,
best_results
=
t
,
(
num_sms
,
nvl_chunk_size
,
rdma_chunk_size
,
config_kwargs
,
)
if
local_rank
==
0
:
print
(
f
"[tuning] SMs
{
num_sms
}
, NVL chunk
{
nvl_chunk_size
}
, RDMA chunk
{
rdma_chunk_size
}
:
{
rdma_send_bytes
/
1e9
/
t
:.
2
f
}
GB/s (RDMA),
{
nvl_recv_bytes
/
1e9
/
t
:.
2
f
}
GB/s (NVL) "
,
flush
=
True
,
)
if
local_rank
==
0
:
print
(
f
'[tuning] Best dispatch (
{
"FP8"
if
isinstance
(
current_x
,
tuple
)
else
"BF16"
}
): SMs
{
best_results
[
0
]
}
, NVL chunk
{
best_results
[
1
]
}
, RDMA chunk
{
best_results
[
2
]
}
:
{
rdma_send_bytes
/
1e9
/
best_time
:.
2
f
}
GB/s (RDMA),
{
nvl_recv_bytes
/
1e9
/
best_time
:.
2
f
}
GB/s (NVL)'
,
flush
=
True
,
)
print
(
""
,
flush
=
True
)
is_fp8
=
isinstance
(
current_x
,
tuple
)
if
is_fp8
:
output_data
[
"normal_dispatch"
]
=
deepcopy
(
best_results
[
3
])
if
isinstance
(
current_x
,
tuple
):
# Gather FP8 the best config from rank 0
best_dispatch_results
=
torch
.
tensor
(
[
best_results
[
0
],
best_results
[
1
],
best_results
[
2
]],
dtype
=
torch
.
int32
,
device
=
"cuda"
,
)
all_best_fp8_results_list
=
[
torch
.
zeros_like
(
best_dispatch_results
)
for
_
in
range
(
torch
.
distributed
.
get_world_size
())
]
dist
.
all_gather
(
all_best_fp8_results_list
,
best_dispatch_results
,
group
=
group
)
best_dispatch_results
=
all_best_fp8_results_list
[
0
].
tolist
()
dispatch_config
=
deep_ep
.
Config
(
best_dispatch_results
[
0
],
best_dispatch_results
[
1
],
nvl_buffer_size
,
best_dispatch_results
[
2
],
rdma_buffer_size
,
)
dispatch_args
=
{
"x"
:
x
,
"num_tokens_per_rank"
:
num_tokens_per_rank
,
"num_tokens_per_rdma_rank"
:
num_tokens_per_rdma_rank
,
"is_token_in_rank"
:
is_token_in_rank
,
"num_tokens_per_expert"
:
num_tokens_per_expert
,
"config"
:
dispatch_config
if
dispatch_config
is
not
None
else
config
,
}
recv_x
,
_
,
_
,
_
,
handle
,
_
=
buffer
.
dispatch
(
**
dispatch_args
)
# Tune combine performance
best_time
,
best_results
=
1e10
,
None
for
nvl_chunk_size
in
range
(
1
,
5
,
1
):
for
rdma_chunk_size
in
range
(
8
,
33
,
4
):
config_kwargs
=
{
"num_sms"
:
num_sms
,
"num_max_nvl_chunked_send_tokens"
:
nvl_chunk_size
,
"num_max_nvl_chunked_recv_tokens"
:
nvl_buffer_size
,
"num_max_rdma_chunked_send_tokens"
:
rdma_chunk_size
,
"num_max_rdma_chunked_recv_tokens"
:
rdma_buffer_size
,
}
config
=
deep_ep
.
Config
(
**
config_kwargs
)
tune_args
=
{
"x"
:
recv_x
,
"handle"
:
handle
,
"config"
:
config
}
t
=
bench
(
lambda
:
buffer
.
combine
(
**
tune_args
))[
0
]
if
local_rank
==
0
:
print
(
f
"[tuning] SMs
{
num_sms
}
, NVL chunk
{
nvl_chunk_size
}
, RDMA chunk
{
rdma_chunk_size
}
:
{
combine_bf16_rdma_recv_bytes
/
1e9
/
t
:.
2
f
}
GB/s (RDMA),
{
combine_bf16_nvl_send_bytes
/
1e9
/
t
:.
2
f
}
GB/s (NVL) "
,
flush
=
True
,
)
if
t
<
best_time
:
best_time
,
best_results
=
t
,
(
num_sms
,
nvl_chunk_size
,
rdma_chunk_size
,
config_kwargs
,
)
if
local_rank
==
0
:
print
(
f
"[tuning] Best combine: SMs
{
best_results
[
0
]
}
, NVL chunk
{
best_results
[
1
]
}
, RDMA chunk
{
best_results
[
2
]
}
:
{
combine_bf16_rdma_recv_bytes
/
1e9
/
best_time
:.
2
f
}
GB/s (RDMA),
{
combine_bf16_nvl_send_bytes
/
1e9
/
best_time
:.
2
f
}
GB/s (NVL)"
,
flush
=
True
,
)
print
(
""
,
flush
=
True
)
output_data
[
"normal_combine"
]
=
deepcopy
(
best_results
[
3
])
if
rank
==
0
and
local_rank
==
0
:
_write_output
(
args
,
output_data
)
def
_write_output
(
args
,
output_data
):
text
=
json
.
dumps
(
output_data
,
indent
=
4
)
output_path
=
args
.
output_path
print
(
f
"Write to
{
output_path
}
with
{
text
}
"
)
Path
(
output_path
).
write_text
(
text
)
# noinspection PyUnboundLocalVariable
def
test_loop
(
local_rank
:
int
,
num_local_ranks
:
int
,
args
):
num_nodes
=
args
.
nnodes
rank
,
num_ranks
,
group
=
init_dist
(
local_rank
,
num_local_ranks
,
args
)
num_sms
=
args
.
num_sms
num_qps_per_rank
=
num_sms
//
2
buffer
=
deep_ep
.
Buffer
(
group
,
int
(
1e9
),
int
(
1e9
),
low_latency_mode
=
False
,
num_qps_per_rank
=
num_qps_per_rank
,
)
assert
num_local_ranks
==
8
and
num_ranks
>
8
torch
.
manual_seed
(
rank
)
for
i
in
(
num_sms
,):
test_main
(
i
,
local_rank
,
num_local_ranks
,
num_ranks
,
num_nodes
,
rank
,
buffer
,
group
,
args
,
)
if
local_rank
==
0
:
print
(
""
,
flush
=
True
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--num-sms"
,
type
=
int
,
default
=
24
)
parser
.
add_argument
(
"--output-path"
,
type
=
str
,
default
=
"deepep_tuned.json"
)
parser
.
add_argument
(
"--nnodes"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--node-rank"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--master-addr"
,
type
=
str
,
default
=
"127.0.0.1"
)
parser
.
add_argument
(
"--master-port"
,
type
=
int
,
default
=
8361
)
args
=
parser
.
parse_args
()
print
(
f
"Start system with
{
args
=
}
"
)
num_processes
=
8
torch
.
multiprocessing
.
spawn
(
test_loop
,
args
=
(
num_processes
,
args
),
nprocs
=
num_processes
)
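The script writes the winning `config_kwargs` for dispatch and combine under the JSON keys `normal_dispatch` and `normal_combine`. How the resulting file is consumed is outside this commit; as a minimal sketch (assuming the script's default output path), the saved kwargs can be fed straight back into `deep_ep.Config`, exactly as the script itself constructs them:

# Minimal sketch (not part of the commit): rebuild DeepEP configs from the tuned JSON.
# The keys below are what tuning_deepep.py writes; downstream consumption by sglang
# is not shown in this commit.
import json
from pathlib import Path

import deep_ep

tuned = json.loads(Path("deepep_tuned.json").read_text())
dispatch_config = deep_ep.Config(**tuned["normal_dispatch"])
combine_config = deep_ep.Config(**tuned["normal_combine"])
# These can then be passed as the `config` argument of Buffer.dispatch / Buffer.combine.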