Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
d79b3cd1
"serialization/vscode:/vscode.git/clone" did not exist on "f6a8d1ea3e062460d03fc5b0592fc6f198c32c2f"
Commit
d79b3cd1
authored
Jul 02, 2025
by
Chenggang Zhao
Browse files
Refactor the bench function
parent
85793dda
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
26 additions
and
34 deletions
+26
-34
tests/test_low_latency.py
tests/test_low_latency.py
+4
-7
tests/utils.py
tests/utils.py
+22
-27
No files found.
tests/test_low_latency.py
View file @
d79b3cd1
...
@@ -144,18 +144,15 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
...
@@ -144,18 +144,15 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
# Separate profiling
# Separate profiling
for
return_recv_hook
in
(
False
,
True
):
for
return_recv_hook
in
(
False
,
True
):
group
.
barrier
()
group
.
barrier
()
bench_outpu
t
=
bench_kineto
(
partial
(
test_func
,
zero_copy
=
True
,
return_recv_hook
=
return_recv_hook
),
dispatch_t
,
combine_
t
=
bench_kineto
(
partial
(
test_func
,
zero_copy
=
True
,
return_recv_hook
=
return_recv_hook
),
kernel_names
=
(
'dispatch'
,
'combine'
),
barrier_comm_profiling
=
True
,
kernel_names
=
(
'dispatch'
,
'combine'
),
barrier_comm_profiling
=
True
,
suppress_kineto_output
=
True
,
duplicate_name
_period
=
2
if
return_recv_hook
else
None
)
suppress_kineto_output
=
True
,
num_kernels_per
_period
=
2
if
return_recv_hook
else
1
)
if
not
return_recv_hook
:
if
not
return_recv_hook
:
dispatch_t
,
combine_t
=
bench_output
print
(
f
'[rank
{
rank
}
] Dispatch bandwidth:
{
num_dispatch_comm_bytes
/
1e9
/
dispatch_t
:.
2
f
}
GB/s, avg_t=
{
dispatch_t
*
1e6
:.
2
f
}
us | '
print
(
f
'[rank
{
rank
}
] Dispatch bandwidth:
{
num_dispatch_comm_bytes
/
1e9
/
dispatch_t
:.
2
f
}
GB/s, avg_t=
{
dispatch_t
*
1e6
:.
2
f
}
us | '
f
'Combine bandwidth:
{
num_combine_comm_bytes
/
1e9
/
combine_t
:.
2
f
}
GB/s, avg_t=
{
combine_t
*
1e6
:.
2
f
}
us'
,
flush
=
True
)
f
'Combine bandwidth:
{
num_combine_comm_bytes
/
1e9
/
combine_t
:.
2
f
}
GB/s, avg_t=
{
combine_t
*
1e6
:.
2
f
}
us'
,
flush
=
True
)
else
:
else
:
dispatch_t
,
combine_t
,
detail_times
=
bench_output
print
(
f
'[rank
{
rank
}
] Dispatch send/recv time:
{
sum
(
dispatch_t
)
*
2
*
1e6
:.
2
f
}
=
{
dispatch_t
[
0
]
*
1e6
:.
2
f
}
+
{
dispatch_t
[
1
]
*
1e6
:.
2
f
}
us | '
print
(
f
'[rank
{
rank
}
] Dispatch send/recv time:
{
dispatch_t
*
2
*
1e6
:.
2
f
}
=
{
detail_times
[
"dispatch"
][
0
]
*
1e6
:.
2
f
}
+
{
detail_times
[
"dispatch"
][
1
]
*
1e6
:.
2
f
}
us | '
f
'Combine send/recv time:
{
sum
(
combine_t
)
*
2
*
1e6
:.
2
f
}
=
{
combine_t
[
0
]
*
1e6
:.
2
f
}
+
{
combine_t
[
1
]
*
1e6
:.
2
f
}
us'
,
flush
=
True
)
f
'Combine send/recv time:
{
combine_t
*
2
*
1e6
:.
2
f
}
=
{
detail_times
[
"combine"
][
0
]
*
1e6
:.
2
f
}
+
{
detail_times
[
"combine"
][
1
]
*
1e6
:.
2
f
}
us'
,
flush
=
True
)
return
hash_value
return
hash_value
...
...
tests/utils.py
View file @
d79b3cd1
...
@@ -8,7 +8,7 @@ import os
...
@@ -8,7 +8,7 @@ import os
import
sys
import
sys
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
typing
import
Optional
from
typing
import
Optional
,
Union
def
init_dist
(
local_rank
:
int
,
num_local_ranks
:
int
):
def
init_dist
(
local_rank
:
int
,
num_local_ranks
:
int
):
...
@@ -154,9 +154,9 @@ class suppress_stdout_stderr:
...
@@ -154,9 +154,9 @@ class suppress_stdout_stderr:
self
.
errnull_file
.
close
()
self
.
errnull_file
.
close
()
def
bench_kineto
(
fn
,
kernel_names
,
num_tests
:
int
=
30
,
suppress_kineto_output
:
bool
=
False
,
def
bench_kineto
(
fn
,
kernel_names
:
Union
[
str
,
tuple
]
,
num_tests
:
int
=
30
,
suppress_kineto_output
:
bool
=
False
,
trace_path
:
Optional
[
str
]
=
None
,
barrier_comm_profiling
:
bool
=
False
,
trace_path
:
Optional
[
str
]
=
None
,
barrier_comm_profiling
:
bool
=
False
,
duplicate_name_period
:
Optional
[
int
]
=
None
):
num_kernels_per_period
:
int
=
1
):
# Profile
# Profile
suppress
=
suppress_stdout_stderr
if
suppress_kineto_output
else
empty_suppress
suppress
=
suppress_stdout_stderr
if
suppress_kineto_output
else
empty_suppress
with
suppress
():
with
suppress
():
...
@@ -175,7 +175,7 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
...
@@ -175,7 +175,7 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
# Parse the profiling table
# Parse the profiling table
assert
isinstance
(
kernel_names
,
str
)
or
isinstance
(
kernel_names
,
tuple
)
assert
isinstance
(
kernel_names
,
str
)
or
isinstance
(
kernel_names
,
tuple
)
is_tuple
d
=
isinstance
(
kernel_names
,
tuple
)
is_tuple
=
isinstance
(
kernel_names
,
tuple
)
prof_lines
=
prof
.
key_averages
().
table
(
sort_by
=
'cuda_time_total'
,
max_name_column_width
=
100
).
split
(
'
\n
'
)
prof_lines
=
prof
.
key_averages
().
table
(
sort_by
=
'cuda_time_total'
,
max_name_column_width
=
100
).
split
(
'
\n
'
)
kernel_names
=
(
kernel_names
,
)
if
isinstance
(
kernel_names
,
str
)
else
kernel_names
kernel_names
=
(
kernel_names
,
)
if
isinstance
(
kernel_names
,
str
)
else
kernel_names
assert
all
([
isinstance
(
name
,
str
)
for
name
in
kernel_names
])
assert
all
([
isinstance
(
name
,
str
)
for
name
in
kernel_names
])
...
@@ -199,29 +199,24 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
...
@@ -199,29 +199,24 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
break
break
break
break
if
duplicate_name_period
is
None
:
# Expand the kernels by periods
return
tuple
(
kernel_times
)
if
is_tupled
else
kernel_times
[
0
]
if
num_kernels_per_period
>
1
:
else
:
with
tempfile
.
NamedTemporaryFile
(
suffix
=
'.json'
)
as
tmp
:
detail_times
=
extract_detail_times_from_prof
(
prof
,
kernel_names
=
kernel_names
,
duplicate_name_period
=
duplicate_name_period
)
return
tuple
(
kernel_times
)
+
(
detail_times
,)
def
extract_detail_times_from_prof
(
prof
,
kernel_names
,
duplicate_name_period
:
int
):
with
tempfile
.
NamedTemporaryFile
(
suffix
=
".json"
)
as
tmp
:
prof
.
export_chrome_trace
(
tmp
.
name
)
prof
.
export_chrome_trace
(
tmp
.
name
)
profile_data
=
json
.
loads
(
Path
(
tmp
.
name
).
read_text
())
profile_data
=
json
.
loads
(
Path
(
tmp
.
name
).
read_text
())
ans
=
{}
for
i
,
kernel_name
in
enumerate
(
kernel_names
):
for
kernel_name
in
kernel_names
:
events
=
[
event
for
event
in
profile_data
[
'traceEvents'
]
if
f
'::
{
kernel_name
}
'
in
event
[
'name'
]]
name_matcher
=
f
'::
{
kernel_name
}
<'
events
=
sorted
(
events
,
key
=
lambda
event
:
event
[
'ts'
])
events
=
[
e
for
e
in
profile_data
[
"traceEvents"
]
if
name_matcher
in
e
[
"name"
]]
durations
=
[
event
[
'dur'
]
/
1e6
for
event
in
events
]
events
=
sorted
(
events
,
key
=
lambda
e
:
e
[
"ts"
])
assert
len
(
durations
)
%
num_kernels_per_period
==
0
durations
=
[
e
[
"dur"
]
/
1e6
for
e
in
events
]
num_kernel_patterns
=
len
(
durations
)
//
num_kernels_per_period
ans
[
kernel_name
]
=
[
list_mean
(
durations
[
i
::
duplicate_name_period
])
for
i
in
range
(
duplicate_name_period
)]
kernel_times
[
i
]
=
[
sum
(
durations
[
j
::
num_kernels_per_period
])
/
num_kernel_patterns
return
ans
for
j
in
range
(
num_kernels_per_period
)]
def
list_mean
(
xs
):
# Return execution times
return
sum
(
xs
)
/
len
(
xs
)
return
kernel_times
if
is_tuple
else
kernel_times
[
0
]
def
hash_tensor
(
t
:
torch
.
Tensor
):
def
hash_tensor
(
t
:
torch
.
Tensor
):
return
t
.
view
(
torch
.
int64
).
sum
().
item
()
return
t
.
view
(
torch
.
int64
).
sum
().
item
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment