Commit d79b3cd1 authored by Chenggang Zhao's avatar Chenggang Zhao
Browse files

Refactor the bench function

parent 85793dda
......@@ -144,18 +144,15 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
# Separate profiling
for return_recv_hook in (False, True):
group.barrier()
bench_output = bench_kineto(partial(test_func, zero_copy=True, return_recv_hook=return_recv_hook),
dispatch_t, combine_t = bench_kineto(partial(test_func, zero_copy=True, return_recv_hook=return_recv_hook),
kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True,
suppress_kineto_output=True, duplicate_name_period=2 if return_recv_hook else None)
suppress_kineto_output=True, num_kernels_per_period=2 if return_recv_hook else 1)
if not return_recv_hook:
dispatch_t, combine_t = bench_output
print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us', flush=True)
else:
dispatch_t, combine_t, detail_times = bench_output
print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t * 2 * 1e6:.2f} = {detail_times["dispatch"][0] * 1e6:.2f} + {detail_times["dispatch"][1] * 1e6:.2f} us | '
f'Combine send/recv time: {combine_t * 2 * 1e6:.2f} = {detail_times["combine"][0] * 1e6:.2f} + {detail_times["combine"][1] * 1e6:.2f} us', flush=True)
print(f'[rank {rank}] Dispatch send/recv time: {sum(dispatch_t) * 2 * 1e6:.2f} = {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | '
f'Combine send/recv time: {sum(combine_t) * 2 * 1e6:.2f} = {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us', flush=True)
return hash_value
......
......@@ -8,7 +8,7 @@ import os
import sys
import torch
import torch.distributed as dist
from typing import Optional
from typing import Optional, Union
def init_dist(local_rank: int, num_local_ranks: int):
......@@ -154,9 +154,9 @@ class suppress_stdout_stderr:
self.errnull_file.close()
def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False,
def bench_kineto(fn, kernel_names: Union[str, tuple], num_tests: int = 30, suppress_kineto_output: bool = False,
trace_path: Optional[str] = None, barrier_comm_profiling: bool = False,
duplicate_name_period: Optional[int] = None):
num_kernels_per_period: int = 1):
# Profile
suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
with suppress():
......@@ -175,7 +175,7 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
# Parse the profiling table
assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
is_tupled = isinstance(kernel_names, tuple)
is_tuple = isinstance(kernel_names, tuple)
prof_lines = prof.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
assert all([isinstance(name, str) for name in kernel_names])
......@@ -199,29 +199,24 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
break
break
if duplicate_name_period is None:
return tuple(kernel_times) if is_tupled else kernel_times[0]
else:
detail_times = extract_detail_times_from_prof(prof, kernel_names=kernel_names, duplicate_name_period=duplicate_name_period)
return tuple(kernel_times) + (detail_times,)
def extract_detail_times_from_prof(prof, kernel_names, duplicate_name_period: int):
with tempfile.NamedTemporaryFile(suffix=".json") as tmp:
# Expand the kernels by periods
if num_kernels_per_period > 1:
with tempfile.NamedTemporaryFile(suffix='.json') as tmp:
prof.export_chrome_trace(tmp.name)
profile_data = json.loads(Path(tmp.name).read_text())
ans = {}
for kernel_name in kernel_names:
name_matcher = f'::{kernel_name}<'
events = [e for e in profile_data["traceEvents"] if name_matcher in e["name"]]
events = sorted(events, key=lambda e: e["ts"])
durations = [e["dur"] / 1e6 for e in events]
ans[kernel_name] = [list_mean(durations[i::duplicate_name_period]) for i in range(duplicate_name_period)]
return ans
def list_mean(xs):
return sum(xs) / len(xs)
for i, kernel_name in enumerate(kernel_names):
events = [event for event in profile_data['traceEvents'] if f'::{kernel_name}' in event['name']]
events = sorted(events, key=lambda event: event['ts'])
durations = [event['dur'] / 1e6 for event in events]
assert len(durations) % num_kernels_per_period == 0
num_kernel_patterns = len(durations) // num_kernels_per_period
kernel_times[i] = [sum(durations[j::num_kernels_per_period]) / num_kernel_patterns
for j in range(num_kernels_per_period)]
# Return execution times
return kernel_times if is_tuple else kernel_times[0]
def hash_tensor(t: torch.Tensor):
return t.view(torch.int64).sum().item()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment