Commit d79b3cd1 authored by Chenggang Zhao's avatar Chenggang Zhao
Browse files

Refactor the bench function

parent 85793dda
...@@ -144,18 +144,15 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -144,18 +144,15 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
# Separate profiling # Separate profiling
for return_recv_hook in (False, True): for return_recv_hook in (False, True):
group.barrier() group.barrier()
bench_output = bench_kineto(partial(test_func, zero_copy=True, return_recv_hook=return_recv_hook), dispatch_t, combine_t = bench_kineto(partial(test_func, zero_copy=True, return_recv_hook=return_recv_hook),
kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True, kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True,
suppress_kineto_output=True, duplicate_name_period=2 if return_recv_hook else None) suppress_kineto_output=True, num_kernels_per_period=2 if return_recv_hook else 1)
if not return_recv_hook: if not return_recv_hook:
dispatch_t, combine_t = bench_output
print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | ' print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us', flush=True) f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us', flush=True)
else: else:
dispatch_t, combine_t, detail_times = bench_output print(f'[rank {rank}] Dispatch send/recv time: {sum(dispatch_t) * 2 * 1e6:.2f} = {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | '
print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t * 2 * 1e6:.2f} = {detail_times["dispatch"][0] * 1e6:.2f} + {detail_times["dispatch"][1] * 1e6:.2f} us | ' f'Combine send/recv time: {sum(combine_t) * 2 * 1e6:.2f} = {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us', flush=True)
f'Combine send/recv time: {combine_t * 2 * 1e6:.2f} = {detail_times["combine"][0] * 1e6:.2f} + {detail_times["combine"][1] * 1e6:.2f} us', flush=True)
return hash_value return hash_value
......
...@@ -8,7 +8,7 @@ import os ...@@ -8,7 +8,7 @@ import os
import sys import sys
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from typing import Optional from typing import Optional, Union
def init_dist(local_rank: int, num_local_ranks: int): def init_dist(local_rank: int, num_local_ranks: int):
...@@ -154,9 +154,9 @@ class suppress_stdout_stderr: ...@@ -154,9 +154,9 @@ class suppress_stdout_stderr:
self.errnull_file.close() self.errnull_file.close()
def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False, def bench_kineto(fn, kernel_names: Union[str, tuple], num_tests: int = 30, suppress_kineto_output: bool = False,
trace_path: Optional[str] = None, barrier_comm_profiling: bool = False, trace_path: Optional[str] = None, barrier_comm_profiling: bool = False,
duplicate_name_period: Optional[int] = None): num_kernels_per_period: int = 1):
# Profile # Profile
suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
with suppress(): with suppress():
...@@ -175,7 +175,7 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: ...@@ -175,7 +175,7 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
# Parse the profiling table # Parse the profiling table
assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
is_tupled = isinstance(kernel_names, tuple) is_tuple = isinstance(kernel_names, tuple)
prof_lines = prof.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') prof_lines = prof.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
assert all([isinstance(name, str) for name in kernel_names]) assert all([isinstance(name, str) for name in kernel_names])
...@@ -199,29 +199,24 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: ...@@ -199,29 +199,24 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
break break
break break
if duplicate_name_period is None: # Expand the kernels by periods
return tuple(kernel_times) if is_tupled else kernel_times[0] if num_kernels_per_period > 1:
else: with tempfile.NamedTemporaryFile(suffix='.json') as tmp:
detail_times = extract_detail_times_from_prof(prof, kernel_names=kernel_names, duplicate_name_period=duplicate_name_period) prof.export_chrome_trace(tmp.name)
return tuple(kernel_times) + (detail_times,) profile_data = json.loads(Path(tmp.name).read_text())
for i, kernel_name in enumerate(kernel_names):
def extract_detail_times_from_prof(prof, kernel_names, duplicate_name_period: int): events = [event for event in profile_data['traceEvents'] if f'::{kernel_name}' in event['name']]
with tempfile.NamedTemporaryFile(suffix=".json") as tmp: events = sorted(events, key=lambda event: event['ts'])
prof.export_chrome_trace(tmp.name) durations = [event['dur'] / 1e6 for event in events]
profile_data = json.loads(Path(tmp.name).read_text()) assert len(durations) % num_kernels_per_period == 0
num_kernel_patterns = len(durations) // num_kernels_per_period
ans = {} kernel_times[i] = [sum(durations[j::num_kernels_per_period]) / num_kernel_patterns
for kernel_name in kernel_names: for j in range(num_kernels_per_period)]
name_matcher = f'::{kernel_name}<'
events = [e for e in profile_data["traceEvents"] if name_matcher in e["name"]] # Return execution times
events = sorted(events, key=lambda e: e["ts"]) return kernel_times if is_tuple else kernel_times[0]
durations = [e["dur"] / 1e6 for e in events]
ans[kernel_name] = [list_mean(durations[i::duplicate_name_period]) for i in range(duplicate_name_period)]
return ans
def list_mean(xs):
return sum(xs) / len(xs)
def hash_tensor(t: torch.Tensor): def hash_tensor(t: torch.Tensor):
return t.view(torch.int64).sum().item() return t.view(torch.int64).sum().item()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment