Unverified Commit 85793dda authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Support displaying separate send and recv time (#239)

* more

* more

* more

* more

* more

* more
parent 77ddb015
...@@ -144,15 +144,17 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -144,15 +144,17 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
# Separate profiling # Separate profiling
for return_recv_hook in (False, True): for return_recv_hook in (False, True):
group.barrier() group.barrier()
dispatch_t, combine_t = bench_kineto(partial(test_func, zero_copy=True, return_recv_hook=return_recv_hook), bench_output = bench_kineto(partial(test_func, zero_copy=True, return_recv_hook=return_recv_hook),
kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True, kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True,
suppress_kineto_output=True) suppress_kineto_output=True, duplicate_name_period=2 if return_recv_hook else None)
if not return_recv_hook: if not return_recv_hook:
dispatch_t, combine_t = bench_output
print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | ' print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us', flush=True) f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us', flush=True)
else: else:
print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t * 2 * 1e6:.2f} us | ' dispatch_t, combine_t, detail_times = bench_output
f'Combine send/recv time: {combine_t * 2 * 1e6:.2f} us', flush=True) print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t * 2 * 1e6:.2f} = {detail_times["dispatch"][0] * 1e6:.2f} + {detail_times["dispatch"][1] * 1e6:.2f} us | '
f'Combine send/recv time: {combine_t * 2 * 1e6:.2f} = {detail_times["combine"][0] * 1e6:.2f} + {detail_times["combine"][1] * 1e6:.2f} us', flush=True)
return hash_value return hash_value
......
import inspect import inspect
import json
import tempfile
from pathlib import Path
import numpy as np import numpy as np
import os import os
import sys import sys
...@@ -151,7 +155,8 @@ class suppress_stdout_stderr: ...@@ -151,7 +155,8 @@ class suppress_stdout_stderr:
def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False, def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False,
trace_path: Optional[str] = None, barrier_comm_profiling: bool = False): trace_path: Optional[str] = None, barrier_comm_profiling: bool = False,
duplicate_name_period: Optional[int] = None):
# Profile # Profile
suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
with suppress(): with suppress():
...@@ -193,8 +198,30 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: ...@@ -193,8 +198,30 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
kernel_times.append(float(time_str.replace(unit, '')) / scale) kernel_times.append(float(time_str.replace(unit, '')) / scale)
break break
break break
return tuple(kernel_times) if is_tupled else kernel_times[0]
if duplicate_name_period is None:
return tuple(kernel_times) if is_tupled else kernel_times[0]
else:
detail_times = extract_detail_times_from_prof(prof, kernel_names=kernel_names, duplicate_name_period=duplicate_name_period)
return tuple(kernel_times) + (detail_times,)
def extract_detail_times_from_prof(prof, kernel_names, duplicate_name_period: int):
with tempfile.NamedTemporaryFile(suffix=".json") as tmp:
prof.export_chrome_trace(tmp.name)
profile_data = json.loads(Path(tmp.name).read_text())
ans = {}
for kernel_name in kernel_names:
name_matcher = f'::{kernel_name}<'
events = [e for e in profile_data["traceEvents"] if name_matcher in e["name"]]
events = sorted(events, key=lambda e: e["ts"])
durations = [e["dur"] / 1e6 for e in events]
ans[kernel_name] = [list_mean(durations[i::duplicate_name_period]) for i in range(duplicate_name_period)]
return ans
def list_mean(xs):
return sum(xs) / len(xs)
def hash_tensor(t: torch.Tensor): def hash_tensor(t: torch.Tensor):
return t.view(torch.int64).sum().item() return t.view(torch.int64).sum().item()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment