Commit 2f64c109 authored by lishen's avatar lishen
Browse files

新增采用mpirun启动的测试方式

parent f5364874
#!/bin/bash
# Launches the DeepEP MPI tests on every node listed in tests_mpi/hostfile.
#
# Usage:
#   ./run_test.sh [--testmode=MODE] [--profiling=TOOL]
#
# Both options are forwarded unchanged to tests_mpi/test_start.sh, which is
# started once per MPI rank. The script exits with mpirun's exit status.

testmode=""
profiling=""
for para in "$@"; do
  if [[ ${para} == --testmode* ]]; then
    testmode=${para#*=}
  elif [[ ${para} == --profiling* ]]; then
    profiling=${para#*=}
  fi
done

CURRENT_DIR=$(cd "$(dirname "$0")" && pwd)
echo "CURRENT_DIR = ${CURRENT_DIR}"
TEST_DIR=${CURRENT_DIR}/tests_mpi
# Please adjust the variables based on the actual NET being used
LAUNCH_WITH_BINDING=${TEST_DIR}/launch_with_binding.sh
DTK_ENV="/opt/dtk/env.sh" # where env.sh of dtk
TEST_ENV=${TEST_DIR}/test_env.sh

#######################################################################################
# Those variables no need to modify
# HOSTFILE="hostfile_$(basename "$0" | sed -E 's/^run_(.+)\.sh$/\1/')"
HOSTFILE="${TEST_DIR}/hostfile"
if [[ ! -r ${HOSTFILE} ]]; then
  echo "hostfile not found or not readable: ${HOSTFILE}" >&2
  exit 1
fi

# One rank per GPU, 8 GPUs per unique hostfile entry. Blank lines and
# '#' comment lines are ignored so they do not inflate the rank count.
NUM_NODES=$(grep -Ev '^[[:space:]]*(#|$)' "${HOSTFILE}" | sort -u | wc -l)
GPUS=$((NUM_NODES * 8))
if ((GPUS == 0)); then
  echo "no hosts listed in ${HOSTFILE}" >&2
  exit 1
fi
# The first listed host is handed to test_start.sh as the master address.
HOST=$(awk '!/^[[:space:]]*#/ && NF { print $1; exit }' "${HOSTFILE}")
PORT="1234"
echo "HOST=${HOST}, PORT=${PORT}, GPUS=${GPUS}"

# Runs aibenchmark model
source "${TEST_ENV}"

# Build the per-rank command with %q so that paths or option values that
# contain spaces or shell metacharacters survive the extra `bash -c` parse.
printf -v REMOTE_CMD \
  'source %q && source %q && export TEST_DIR=%q && %q %q %q --launch_with_binding=%q --testmode=%q --profiling=%q' \
  "${DTK_ENV}" "${TEST_ENV}" "${TEST_DIR}" "${TEST_DIR}/test_start.sh" \
  "${HOST}" "${PORT}" "${LAUNCH_WITH_BINDING}" "${testmode}" "${profiling}"

mpirun -np "${GPUS}" --hostfile "${HOSTFILE}" \
  --allow-run-as-root \
  bash -c "${REMOTE_CMD}"
#> log-$((${GPUS} / 8))nodes-`date +%F-%H%M`.log 2>&1
status=$?

# Propagate the launcher's result instead of always reporting success.
exit "${status}"
## 修改 hostfile,指定测试的模式
# ./run_test.sh --testmode=intranode
# ./run_test.sh --testmode=internode
# ./run_test.sh --testmode=lowlatency
node037 slots=8
node038 slots=8
#!/bin/bash
# Runs a command bound to the NUMA node that numa_map assigns to a local rank.
#
# Usage: launch_with_binding.sh LOCAL_RANK COMMAND [ARGS...]
#
# LOCAL_RANK indexes numa_map; the selected node is used for both
# --cpunodebind and --membind of numactl. The exit status is numactl's.

# wz
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
numa_map=(0 1 2 3 4 5 6 7)
# # 508 Tianlong
# export HIP_VISIBLE_DEVICES=0,1,2,3,5,4,7,6
# numa_map=(0 1 2 3 5 4 7 6)
# 508 mlxn
# export HIP_VISIBLE_DEVICES=0,1,2,3,5,4,7,6
# corenum=`cat /proc/cpuinfo| grep "cpu cores" | uniq | awk '{print $4}'`
# if [ "${corenum}" -eq "64" ]
# then
# numa_map=(0 3 2 1 7 4 5 6)
# else
# numa_map=(0 2 3 5 8 6 11 9)
# fi

if (($# < 2)); then
  echo "usage: ${0##*/} LOCAL_RANK COMMAND [ARGS...]" >&2
  exit 2
fi

LOCAL_RANK=$1
shift

# Reject anything that is not a valid index into numa_map; otherwise numactl
# would be started with an empty node id.
if ! [[ ${LOCAL_RANK} =~ ^[0-9]+$ ]] || ((10#${LOCAL_RANK} >= ${#numa_map[@]})); then
  echo "invalid local rank '${LOCAL_RANK}' (numa_map has ${#numa_map[@]} entries)" >&2
  exit 1
fi
NUMA_ID=${numa_map[10#${LOCAL_RANK}]}

if ((10#${LOCAL_RANK} == 0)); then
  # [*] prints the whole map; a bare ${numa_map} would show only element 0.
  echo "numa_map: ${numa_map[*]}"
fi

numactl --cpunodebind="${NUMA_ID}" --membind="${NUMA_ID}" "$@"
#!/bin/bash
# =============================================================================
# DeepEP + RCCL/NCCL environment configuration
# Intended for: multi-node runs started through mpirun
# Network: InfiniBand (SHCA) or RoCE
#
# This file is meant to be sourced. Paths are resolved relative to the
# location of this file, so the result does not depend on the caller's
# working directory.
# =============================================================================

# Directory that contains this file (topo.config is expected next to it).
_test_env_dir=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)

# rocSHMEM
export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
export ROCSHMEM_MAX_NUM_CONTEXTS=48
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export ROCSHMEM_HEAP_SIZE=10737418240
export ROCSHMEM_TOPO_FILE_FORCE="${_test_env_dir}/topo.config"

# duSHMEM
# Only add the ':' separator when LD_LIBRARY_PATH already has content; a
# trailing ':' would put the current directory on the library search path.
export LD_LIBRARY_PATH="/opt/dtk/dushmem/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}"
export DEEP_EP_DEVICE_TO_HCA_MAPPING=0:mlx5_2:1,1:mlx5_3:1,2:mlx5_4:1,3:mlx5_5:1,4:mlx5_6:1,5:mlx5_7:1,6:mlx5_8:1,7:mlx5_9:1
export NVSHMEM_SYMMETRIC_SIZE=10737418240
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 # adjust to the hardware topology

# Parent of this directory first, followed by whatever PYTHONPATH was already
# set (for example by a previously sourced environment script).
export PYTHONPATH="$(dirname "${_test_env_dir}")${PYTHONPATH:+:${PYTHONPATH}}"

unset _test_env_dir
This diff is collapsed.
import argparse
import os
import time
import torch
import torch.distributed as dist
import socket
# noinspection PyUnresolvedReferences
import deep_ep
from utils import init_dist, bench, calc_diff, inplace_unique, per_token_cast_to_fp8, per_token_cast_pg_back
# Test compatibility with low latency functions
import test_low_latency
# noinspection PyShadowingNames
def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks: int, rank: int,
buffer: deep_ep.Buffer, group: dist.ProcessGroup):
# Exercises `buffer.get_dispatch_layout`, `buffer.dispatch` and `buffer.combine`
# against reference values built with plain torch ops, then times a range of
# chunk sizes for dispatch and combine.
# `args` supplies num_tokens / hidden / num_topk / num_experts; `num_sms` is
# handed to `deep_ep.Config`. Progress output is printed only when
# `local_rank == 0`. There is no return statement; failures surface as
# AssertionError.
# NOTE(review): leading indentation is missing in this copy of the file, so
# the nesting of the statements below must be confirmed against the repository.
# Settings
num_tokens, hidden = args.num_tokens, args.hidden
num_topk, num_experts = args.num_topk, args.num_experts
assert num_experts % num_ranks == 0
if local_rank == 0:
print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk={num_topk}', flush=True)
# Random data
# `x` is constant: every element equals this process's `rank`; check_data below
# relies on received rows being constant.
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
x_e4m3 = per_token_cast_to_fp8(x)
# Same scale values, but re-laid-out so that the transpose of the scale tensor is contiguous.
x_e4m3 = (x_e4m3[0], x_e4m3[1].T.contiguous().T) if x_e4m3 is not None else None
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1]
# topk_idx = topk_idx.to(deep_ep.topk_idx_t)
topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank
topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda')
# Map every selected expert id to the index of the rank hosting it
# (num_experts // num_ranks experts per rank); -1 entries stay -1.
rank_idx = topk_idx // (num_experts // num_ranks)
rank_idx = rank_idx.to(torch.int64)
rank_idx.masked_fill_(topk_idx == -1, -1)
# presumably removes duplicate rank ids within each token row (helper from utils) - TODO confirm
inplace_unique(rank_idx, num_ranks)
# Expert meta
# Local count of selections per expert, then the all-reduced total over `group`.
num_tokens_per_expert = torch.zeros((num_experts, ), dtype=torch.int, device='cuda')
for i in range(num_experts):
num_tokens_per_expert[i] = (topk_idx == i).sum()
gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
dist.all_reduce(gbl_num_tokens_per_expert, group=group)
# Rank layout meta
# For each rank i: count the entries of rank_idx equal to i, and number the
# tokens that target rank i as 0..count-1 in ascending token order; tokens
# that do not target rank i keep -1.
num_tokens_per_rank = torch.empty((num_ranks, ), dtype=torch.int, device='cuda')
token_idx_in_rank = torch.full((num_ranks, num_tokens), -1, dtype=torch.long, device='cuda')
for i in range(num_ranks):
num_tokens_per_rank[i] = (rank_idx == i).sum()
token_sel = (rank_idx == i).max(dim=-1)[0]
count = token_sel.sum().item()
tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]
tokens[:count] = torch.sort(tokens[:count])[0]
token_idx_in_rank[i][tokens[:count]] = torch.arange(count, dtype=torch.long, device='cuda')
token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)
is_token_in_rank = token_idx_in_rank >= 0
gbl_num_tokens_per_rank = num_tokens_per_rank.clone()
dist.all_reduce(gbl_num_tokens_per_rank, group=group)
# The layout reported by the buffer must match the reference computed above.
ref_num_tokens_per_rank, _, ref_num_tokens_per_expert, ref_is_token_in_rank, _ = \
buffer.get_dispatch_layout(topk_idx, num_experts)
assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank)
assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert)
assert torch.allclose(ref_is_token_in_rank, is_token_in_rank)
# The first element of bench's result is scaled by 1000 and printed as milliseconds.
t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0]
if local_rank == 0:
print(f'[layout] Kernel performance: {t * 1000:.3f} ms', flush=True)
print('', flush=True)
group.barrier()
time.sleep(1)
# Config
nvl_buffer_size = 256
config = deep_ep.Config(num_sms, 8, nvl_buffer_size)
# Test dispatch
# check_data: every row of `check_x` must be constant, and the rows between
# consecutive rank_prefix_matrix[i][rank] boundaries must equal i.
# noinspection PyShadowingNames
def check_data(check_x, rank_prefix_matrix):
assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1))
check_start = 0
for i in range(num_ranks):
check_end = rank_prefix_matrix[i][rank].item()
assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0
check_start = check_end
# Correctness sweep over: previous-event on/off, async on/off, the input
# tensors (random BF16, constant BF16, FP8 tuple) and with/without top-k.
for previous_mode in (False, True):
for async_mode in (False, True):
for current_x in filter(lambda elem: elem is not None, (x_pure_rand, x, x_e4m3)):
for with_topk in (False, True):
if local_rank == 0:
print(f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', flush=True, end='')
dispatch_args = {'x': current_x, 'num_tokens_per_rank': num_tokens_per_rank, 'is_token_in_rank': is_token_in_rank,
'num_tokens_per_expert': num_tokens_per_expert, 'config': config, 'async_finish': async_mode}
if with_topk:
dispatch_args.update({'topk_idx': topk_idx, 'topk_weights': topk_weights_pure_rand if current_x is x_pure_rand else topk_weights})
if previous_mode:
dispatch_args.update({'previous_event': buffer.capture()})
recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch(**dispatch_args)
# Conditional expression used as a statement: wait on the event only in async mode.
event.current_stream_wait() if async_mode else ()
# A tuple result (FP8 data + scales) is converted back with the utils helper before checking.
recv_x = per_token_cast_pg_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
# Checks
rank_prefix_matrix = handle[0]
assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(0), f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}'
assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list
if current_x is not x_pure_rand:
check_data(recv_x, rank_prefix_matrix)
recv_topk_weights_clone = None
if with_topk:
# Check `topk_idx`
# Every received index is either -1 or a local expert index in [0, num_experts // num_ranks).
assert (recv_topk_idx.eq(-1) | ((recv_topk_idx >= 0) & (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel()
for i, count in enumerate(recv_num_tokens_per_expert_list):
assert recv_topk_idx.eq(i).sum().item() == count
# Check `topk_weights`
# The clone keeps the unmodified weights for the `num_worst_tokens` comparison below.
recv_topk_weights_clone = recv_topk_weights.clone()
if current_x is not x_pure_rand:
# Overwrite the weights at -1 positions with the row maximum so that check_data sees constant rows.
recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)]
check_data(recv_topk_weights, rank_prefix_matrix)
# Test `num_worst_tokens != 0`
# With num_worst_tokens set, outputs have num_tokens * num_ranks rows: the
# leading rows must equal the regular dispatch result and the remaining
# top-k indices must all be -1.
if with_topk:
num_worst_tokens = num_tokens * num_ranks
dispatch_args.update({'num_worst_tokens': num_worst_tokens})
recv_worst_x, recv_worst_topk_idx, recv_worst_topk_weights, empty_list, _, event = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else ()
recv_worst_x = per_token_cast_pg_back(*recv_worst_x) if isinstance(recv_worst_x, tuple) else recv_worst_x
assert len(empty_list) == 0
assert num_worst_tokens == recv_worst_x.size(0)
assert num_worst_tokens == recv_worst_topk_idx.size(0)
assert num_worst_tokens == recv_worst_topk_weights.size(0)
assert torch.equal(recv_x, recv_worst_x[:recv_x.size(0)])
assert torch.equal(recv_topk_idx, recv_worst_topk_idx[:recv_x.size(0)])
assert torch.equal(recv_topk_weights_clone, recv_worst_topk_weights[:recv_x.size(0)])
assert torch.all(recv_worst_topk_idx[recv_x.size(0):] == -1).item()
# Test cached dispatch (must be done without top-k arguments)
# The `handle` from the first dispatch is passed in place of the layout tensors.
if not with_topk:
dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
if previous_mode:
dispatch_args.update({'previous_event': buffer.capture()})
recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else ()
recv_x = per_token_cast_pg_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
if current_x is not x_pure_rand:
check_data(recv_x, rank_prefix_matrix)
# Test combine
# The combined result is divided by the number of ranks each token was sent
# to before it is compared with the original input.
combine_args = {'x': recv_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
if with_topk:
combine_args.update({'topk_weights': recv_topk_weights})
if previous_mode:
combine_args.update({'previous_event': buffer.capture()})
combined_x, combined_topk_weights, event = buffer.combine(**combine_args)
event.current_stream_wait() if async_mode else ()
check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1)
ref_x = x_pure_rand if current_x is x_pure_rand else x
assert calc_diff(check_x, ref_x) < 5e-6
if with_topk:
check_topk_weights = combined_topk_weights if (current_x is x_pure_rand) else (combined_topk_weights / is_token_in_rank.sum(dim=1).unsqueeze(1))
ref_topk_weights = topk_weights_pure_rand if current_x is x_pure_rand else topk_weights
assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9
# For later tuning
# Two bytes per received element; the same figure is reused for the combine send size.
dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2
combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes
if local_rank == 0:
print(' passed', flush=True)
if local_rank == 0:
print('', flush=True)
# Tune dispatch performance
# Times dispatch for chunk sizes 4, 6, ..., 32 and finally the default config
# (chunk size 0 selects `Buffer.get_dispatch_config`); the default is timed
# but never recorded as the best result.
best_dispatch_results = None
fp8_factor = (1 + 4 / 128) / 2
for current_x in filter(lambda elem: elem is not None, (x_e4m3, x)):
best_time, best_results = 1e10, None
nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes
for nvl_chunk_size in tuple(range(4, 33, 2)) + (0, ):
if nvl_chunk_size > 0:
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
else:
# Test default config as well
deep_ep.Buffer.set_num_sms(num_sms)
config = deep_ep.Buffer.get_dispatch_config(num_ranks)
tune_args = {'x': current_x, 'handle': handle, 'config': config}
t = bench(lambda: buffer.dispatch(**tune_args))[0]
if t < best_time and nvl_chunk_size > 0:
best_time, best_results = t, (num_sms, nvl_chunk_size)
if local_rank == 0:
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: '
f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL), {t * 1e6:.2f} us', flush=True)
if local_rank == 0:
print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL), t: {best_time * 1e6:.2f} us', flush=True)
print('', flush=True)
# Gather the best config from rank 0 and the first test setting
# Every process contributes its own best pair; element 0 of the gathered list is the one kept.
if best_dispatch_results is None:
best_dispatch_results = torch.tensor([best_results[0], best_results[1]], dtype=torch.int32, device='cuda')
all_best_fp8_results_list = [torch.zeros_like(best_dispatch_results) for _ in range(torch.distributed.get_world_size())]
dist.all_gather(all_best_fp8_results_list, best_dispatch_results, group=group)
best_dispatch_results = all_best_fp8_results_list[0].tolist()
dispatch_config = deep_ep.Config(best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size)
# Dispatch once more with the selected config to obtain `recv_x` and `handle` for combine tuning.
dispatch_args = {'x': x, 'num_tokens_per_rank': num_tokens_per_rank,
'is_token_in_rank': is_token_in_rank, 'num_tokens_per_expert': num_tokens_per_expert,
'config': dispatch_config if dispatch_config is not None else config}
recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)
# Tune combine performance
# Same scheme as above with chunk sizes 1..16 followed by the default combine config.
best_time, best_results = 1e10, None
for nvl_chunk_size in tuple(range(1, 17, 1)) + (0, ):
if nvl_chunk_size > 0:
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
else:
# Test default config as well
deep_ep.Buffer.set_num_sms(num_sms)
config = deep_ep.Buffer.get_combine_config(num_ranks)
tune_args = {'x': recv_x, 'handle': handle, 'config': config}
t = bench(lambda: buffer.combine(**tune_args))[0]
if local_rank == 0:
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: '
f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL), {t * 1e6:.2f} us', flush=True)
if t < best_time and nvl_chunk_size > 0:
best_time, best_results = t, (num_sms, nvl_chunk_size)
if local_rank == 0:
print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL), t: {best_time * 1e6:.2f} us', flush=True)
print('', flush=True)
# noinspection PyShadowingNames
def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
    """Set up the process group and buffer, run the intranode test, tear down.

    Args:
        local_rank: rank of this process on its node; forwarded to `test_main`.
        num_local_ranks: accepted for interface compatibility; not used here.
        args: parsed command line (rank, world_size, local_rank, dist_url,
            num_processes and the test sizes read by `test_main`).
    """
    rank, num_ranks, group = init_dist(args.rank, args.world_size, args.local_rank, args.dist_url)
    num_nodes = args.world_size // args.num_processes

    # The address is only logged, so a host name that cannot be resolved
    # must not abort the test.
    hostname = socket.gethostname()
    try:
        ip = socket.gethostbyname(hostname)
    except OSError:
        ip = 'unknown'
    print(f"rank={rank}, num_ranks={num_ranks}, num_nodes={num_nodes}, ip={ip}")

    # Low-latency compatibility test settings; only used when the switch is on.
    test_ll_compatibility, num_rdma_bytes = False, 0
    ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
    if test_ll_compatibility:
        num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(ll_num_tokens, ll_hidden, num_ranks, ll_num_experts)

    buffer = deep_ep.Buffer(group, int(2e9), num_rdma_bytes, low_latency_mode=test_ll_compatibility,
                            num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1),
                            explicitly_destroy=True)
    torch.manual_seed(rank)

    # One run per SM count to test.
    for num_sms in (48, ):
        test_main(args, num_sms, local_rank, num_ranks, rank, buffer, group)
        if local_rank == 0:
            print('', flush=True)

    # Test compatibility with low latency functions
    if test_ll_compatibility:
        buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
        test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)

    # Destroy the buffer runtime and communication group
    buffer.destroy()
    dist.barrier()
    dist.destroy_process_group()
def build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the intranode EP test.

    The defaults of --rank, --world-size and --local-rank are read from the
    OMPI_COMM_WORLD_RANK, OMPI_COMM_WORLD_SIZE and OMPI_COMM_WORLD_LOCAL_RANK
    environment variables at the time this function is called.
    """
    parser = argparse.ArgumentParser(description='Test intranode EP kernels')
    group = parser.add_argument_group(title='extra distributed args')
    # Must stay non-negative: the value is handed to init_dist as the process rank.
    group.add_argument('--rank', type=int, default=int(os.getenv('OMPI_COMM_WORLD_RANK', '0')),
                       help='global rank of this process (default: OMPI_COMM_WORLD_RANK)')
    group.add_argument('--world-size', type=int, default=int(os.getenv('OMPI_COMM_WORLD_SIZE', '0')),
                       help='total number of processes (default: OMPI_COMM_WORLD_SIZE)')
    group.add_argument('--local-rank', type=int, default=int(os.getenv('OMPI_COMM_WORLD_LOCAL_RANK', '0')),
                       help='local rank passed from distributed launcher.')
    group.add_argument('--dist-url',
                       help='Which master node url for distributed training.')
    parser.add_argument('--num-processes', type=int, default=8,
                        help='Number of processes to spawn (default: 8)')
    parser.add_argument('--num-tokens', type=int, default=4096,
                        help='Number of tokens (default: 4096)')
    parser.add_argument('--hidden', type=int, default=7168,
                        help='Hidden dimension size (default: 7168)')
    parser.add_argument('--num-topk', type=int, default=8,
                        help='Number of top-k experts (default: 8)')
    parser.add_argument('--num-experts', type=int, default=256,
                        help='Number of experts (default: 256)')
    return parser


if __name__ == '__main__':
    args = build_arg_parser().parse_args()
    # The intranode test only runs when the world fits into one node's process count.
    if args.world_size <= args.num_processes:
        test_loop(args.local_rank, args.num_processes, args)
    elif args.rank == 0:
        # Say why nothing ran instead of exiting silently.
        print(f'[skip] intranode test not run: world size {args.world_size} exceeds '
              f'--num-processes {args.num_processes}', flush=True)
import argparse
import os
import random
import torch
import torch.distributed as dist
import socket
from functools import partial
from typing import Literal, Set
import deep_ep
from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_pg_back, per_token_cast_pc_back
def simulate_failure_and_skip(rank: int, api: Literal["dispatch", "combine", "clean"], expected_masked_ranks: Set[int]):
    """Decide whether `rank` should skip a communication API call.

    Each known API name has one fixed rank that "fails" the first time that
    API is reached. That rank is added to `expected_masked_ranks` (mutated in
    place) whenever the API name is known and the caller is not masked already.

    Returns:
        True when `rank` is already in `expected_masked_ranks` or is the rank
        that fails on `api`; False otherwise, including for unknown API names.
    """
    # A rank that has already failed keeps skipping and records nothing new.
    if rank in expected_masked_ranks:
        return True

    # API name -> rank that fails when it first calls that API.
    victim_by_api = {'dispatch': 1, 'combine': 3, 'clean': 5}
    victim = victim_by_api.get(api)
    if victim is None:
        return False

    expected_masked_ranks.add(victim)
    if victim != rank:
        return False
    print(f"Rank {rank} failed when first calling {api} communication API, exit...", flush=True)
    return True
def query_mask_buffer_and_check(api: Literal["dispatch", "combine", "clean"], buffer: deep_ep.Buffer, mask_status: torch.Tensor,
                                expected_masked_ranks: Set[int]):
    """Query the buffer's mask state and assert it matches the expected ranks.

    `buffer.low_latency_query_mask_buffer` is called with `mask_status`; the
    indices of its non-zero entries must then equal `expected_masked_ranks`.
    `api` is accepted but not used by this function.
    """
    buffer.low_latency_query_mask_buffer(mask_status)
    masked_now = set(mask_status.nonzero().squeeze(-1).tolist())
    assert masked_now == expected_masked_ranks
def test_main(num_tokens: int,
hidden: int,
num_experts: int,
num_topk: int,
rank: int,
num_ranks: int,
group: dist.ProcessGroup,
buffer: deep_ep.Buffer,
use_logfmt: bool = False,
seed: int = 0):
# Checks `buffer.low_latency_dispatch` / `buffer.low_latency_combine` for the
# listed quantization settings, then measures bandwidth and kernel times.
# Seeds torch and `random` with `seed + rank`. Returns `hash_value`, an XOR of
# hashes of the received and combined tensors; callers can compare it
# between runs that use the same seed.
# NOTE(review): leading indentation is missing in this copy of the file, so
# the nesting of the statements below must be confirmed against the repository.
torch.manual_seed(seed + rank)
random.seed(seed + rank)
assert num_experts % num_ranks == 0
num_local_experts = num_experts // num_ranks
# NOTES: the integers greater than 256 exceed the BF16 precision limit
rank_offset = 128
assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)'
# Every element of `x` is `rank - rank_offset`, except the last 128 columns,
# which hold the row (token) index.
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset)
x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1)
x_list = [x]
for _ in range(4 if use_logfmt else 0):
# NOTES: make more LogFMT casts and also with some BF16
x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.5 * random.random())
# NOTES: the last one is for performance testing
# Most of the values in the perf case is lower than the threshold, casting most channels
x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1)
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs()
# Randomly mask some positions
for _ in range(10):
topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1
# Collect every rank's `topk_idx` into one (num_ranks, num_tokens, num_topk) tensor.
all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda')
dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
# For failure simulation and shrink testing
# Neither object is modified anywhere in this function as shown, so
# `mask_status == 0` below selects every rank.
mask_status = torch.zeros((num_ranks,), dtype=torch.int, device='cuda')
expected_masked_ranks = set()
# Check dispatch correctness
do_check = True
hash_value, num_times = 0, 0
for x_i, current_x in enumerate(x_list):
for return_recv_hook in (False, True):
for quant_type in (0, 1, 2, 3, ): # 0: no quantization, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (only supports round_scale=True), 4: FP8_E5M2
dispatch_use_quant = quant_type > 0
for fp8_round_scale in (False, True) if quant_type != 3 else (True, ):
for quant_group_size in (0, 128,) if quant_type >= 2 else (0, ):
# For quant_type 3 the loop above only yields fp8_round_scale=True, so this
# effectively skips the quant_group_size == 0 combination.
if quant_type == 3 and (fp8_round_scale == False or quant_group_size == 0):
continue
num_times += 1
# Dispatch once or twice, depending on the parity of `num_times`.
for _ in range((num_times % 2) + 1):
packed_recv_x, packed_recv_count, handle, event, hook = \
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
quant_type=quant_type, fp8_round_scale=fp8_round_scale, quant_group_size=quant_group_size,
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
# Complete the receive through the hook when one was requested, otherwise wait on the event.
hook() if return_recv_hook else event.current_stream_wait()
# With quantization the result is indexed as a pair; element 1 is made contiguous.
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_quant else packed_recv_x
# `simulated_gemm_x` is the tensor later fed to combine: a clone of the received
# data, or the result of the matching cast-back helper from utils when quantized.
if not dispatch_use_quant:
simulated_gemm_x = packed_recv_x.clone()
elif quant_group_size == 0:
simulated_gemm_x = per_token_cast_pc_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].reshape(-1)).view(packed_recv_x[0].shape)
elif quant_group_size == 128:
simulated_gemm_x = per_token_cast_pg_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape)
# Per-local-expert checks; the range is empty when `do_check` is False.
for i in range(num_local_experts if do_check else 0):
expert_id = rank * num_local_experts + i
if not dispatch_use_quant:
recv_x = packed_recv_x[i]
elif quant_group_size == 0:
recv_x = per_token_cast_pc_back(packed_recv_x[0][i], packed_recv_x[1][i])
elif quant_group_size == 128:
recv_x = per_token_cast_pg_back(packed_recv_x[0][i], packed_recv_x[1][i])
recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]
# Check expert indices
# The low 32 bits of each `recv_layout_range` entry are summed as a token count;
# the upper bits are read further down as a start index.
int_mask = (2 ** 32) - 1
num_valid_tokens = recv_count.item()
assert num_valid_tokens == (
recv_layout_range
& int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()'
assert num_valid_tokens == (all_topk_idx == expert_id).sum(dim=[1, 2])[mask_status == 0].sum().item(
), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum(dim=[1, 2])[mask_status == 0].sum().item()}'
if num_valid_tokens == 0:
continue
# Check received data
# Only the structured input `x` is checked element-wise: all columns except the
# last 128 must be constant per row.
if current_x is x:
recv_x = recv_x[:num_valid_tokens]
recv_x_amin = recv_x[:, :-128].amin(dim=-1)
recv_x_amax = recv_x[:, :-128].amax(dim=-1)
recv_src_info = recv_src_info[:num_valid_tokens]
assert torch.equal(recv_x_amin, recv_x_amax)
# NOTE(review): with the indentation lost it is not clear which `if` the `else`
# below belongs to, nor whether `if quant_group_size != 0` is nested inside
# `if dispatch_use_quant` - confirm against the repository file.
if dispatch_use_quant:
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
if quant_group_size != 0:
if fp8_round_scale:
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
else:
assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
# Per source rank j: start index from the upper bits, token count from the low 32 bits.
for j in range(num_ranks):
begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
if not fp8_round_scale:
assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0
# Fold the valid part of the received data into the running hash.
if dispatch_use_quant:
hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
else:
hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
# Check combine correctness
# The zero-copy variant is only exercised when `use_logfmt` is False.
for zero_copy in (False, ) if use_logfmt else (False, True, ):
if zero_copy:
buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
topk_idx,
topk_weights,
handle,
use_logfmt=use_logfmt,
async_finish=not return_recv_hook,
zero_copy=zero_copy,
return_recv_hook=return_recv_hook,
out=out)
hook() if return_recv_hook else event.current_stream_wait()
if do_check:
# Reference: each input row scaled by the sum of its top-k weights, with
# the weights of masked (-1) selections set to zero.
diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
assert torch.isnan(combined_x).sum().item() == 0
# if not fp8_round_scale:
assert diff < (9e-4 if dispatch_use_quant else 1e-5), f'Error: diff={diff}, dispatch_use_quant={dispatch_use_quant}, zero_copy={zero_copy}'
hash_value ^= hash_tensor(combined_x)
if rank == 0:
print(f"data:{x_i}, return_recv_hook:{return_recv_hook}, quant_type:{quant_type}, ",
f"fp8_round_scale:{fp8_round_scale}, quant_group_size:{quant_group_size} pass")
if rank == 0:
print('', flush=True)
# Multiplies two random 8192x8192 float matrices, discards the product, then
# calls `hook`.
# noinspection PyShadowingNames
def large_gemm_with_hook(hook):
mat_0 = torch.randn((8192, 8192), dtype=torch.float)
mat_1 = torch.randn((8192, 8192), dtype=torch.float)
mat_0 @ mat_1
hook()
# One dispatch + combine round used for timing. It reads `current_x` and
# `simulated_gemm_x` as they were last assigned by the loops above.
# noinspection PyShadowingNames
def test_func(return_recv_hook: bool):
recv_x, recv_count, handle, event, hook = \
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
quant_type=2, quant_group_size=0,
async_finish=False, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
topk_idx,
topk_weights,
handle,
use_logfmt=use_logfmt,
return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None
# Calculate bandwidth
# Byte counts per token and selection; `num_logfmt10_bytes` is computed but not
# referenced again in this function.
num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
num_logfmt10_bytes = hidden * 10 / 8 + hidden / 128 * 4
num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
for i in range(num_tokens):
num_selections = (topk_idx[i] != -1).sum().item()
num_dispatch_comm_bytes += num_fp8_bytes * num_selections
num_combine_comm_bytes += num_bf16_bytes * num_selections
# Dispatch + combine testing
avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False))
print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, '
f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us',
flush=True)
# Separate profiling
# With the receive hook enabled each result is indexed as a (send, recv) pair
# when printed; otherwise it is used as a single time.
for return_recv_hook in (False, True):
group.barrier()
dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook),
kernel_names=('dispatch', 'combine'),
barrier_comm_profiling=True,
suppress_kineto_output=True,
num_kernels_per_period=2 if return_recv_hook else 1)
if not return_recv_hook:
print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us',
flush=True)
else:
print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | '
f'Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us',
flush=True)
return hash_value
# noinspection PyShadowingNames
def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
    """Set up the process group and buffer, run the low-latency test, tear down.

    Args:
        local_rank: accepted for interface compatibility; not used here.
        num_local_ranks: accepted for interface compatibility; not used here.
        args: parsed command line (distributed settings, test sizes,
            disable_nvlink, allow_mnnvl, use_logfmt and pressure_test).

    With `args.pressure_test` set, the test is repeated for increasing seeds
    and every seed is re-run 20 times; each re-run must return the same hash
    as the first run with that seed.
    """
    rank, num_ranks, group = init_dist(args.rank, args.world_size, args.local_rank, args.dist_url)
    num_nodes = args.world_size // args.num_processes

    # The address is only logged, so a host name that cannot be resolved
    # must not abort the test.
    hostname = socket.gethostname()
    try:
        ip = socket.gethostbyname(hostname)
    except OSError:
        ip = 'unknown'
    print(f"rank={rank}, num_ranks={num_ranks}, num_nodes={num_nodes}, ip={ip}")

    num_tokens, hidden = args.num_tokens, args.hidden
    num_topk, num_experts = args.num_topk, args.num_experts
    print(f"num_tokens, hidden, num_ranks, num_experts = {num_tokens}, {hidden}, {num_ranks}, {num_experts}")

    num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
    if rank == 0:
        print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
    buffer = deep_ep.Buffer(group,
                            num_rdma_bytes=num_rdma_bytes,
                            low_latency_mode=True,
                            num_qps_per_rank=num_experts // num_ranks,
                            allow_nvlink_for_low_latency_mode=not args.disable_nvlink,
                            explicitly_destroy=True,
                            allow_mnnvl=args.allow_mnnvl)

    # Every invocation below differs only in the seed.
    run_test = partial(test_main, num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer,
                       use_logfmt=args.use_logfmt)
    run_test(seed=1)

    do_pressure_test = args.pressure_test
    for seed in range(int(1e9) if do_pressure_test else 0):
        if rank == 0:
            print(f'Testing with seed {seed} ...', flush=True)
        ref_hash = run_test(seed=seed)
        for _ in range(20):
            assert run_test(seed=seed) == ref_hash, f'Error: seed={seed}'

    # Destroy the buffer runtime and communication group
    buffer.destroy()
    dist.barrier()
    dist.destroy_process_group()
def build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the low-latency EP test.

    The defaults of --rank, --world-size and --local-rank are read from the
    OMPI_COMM_WORLD_RANK, OMPI_COMM_WORLD_SIZE and OMPI_COMM_WORLD_LOCAL_RANK
    environment variables at the time this function is called.
    """
    # TODO: you may modify NUMA binding for less CPU overhead
    parser = argparse.ArgumentParser(description='Test low-latency EP kernels')
    group = parser.add_argument_group(title='extra distributed args')
    # Must stay non-negative: the value is handed to init_dist as the process rank.
    group.add_argument('--rank', type=int, default=int(os.getenv('OMPI_COMM_WORLD_RANK', '0')),
                       help='global rank of this process (default: OMPI_COMM_WORLD_RANK)')
    group.add_argument('--world-size', type=int, default=int(os.getenv('OMPI_COMM_WORLD_SIZE', '0')),
                       help='total number of processes (default: OMPI_COMM_WORLD_SIZE)')
    group.add_argument('--local-rank', type=int, default=int(os.getenv('OMPI_COMM_WORLD_LOCAL_RANK', '0')),
                       help='local rank passed from distributed launcher.')
    group.add_argument('--dist-url',
                       help='Which master node url for distributed training.')
    parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)')
    parser.add_argument('--num-tokens', type=int, default=128, help='Number of tokens (default: 128)')
    parser.add_argument('--hidden', type=int, default=7168, help='Hidden dimension size (default: 7168)')
    parser.add_argument('--num-topk', type=int, default=8, help='Number of top-k experts (default: 8)')
    parser.add_argument('--num-experts', type=int, default=288, help='Number of experts (default: 288)')
    parser.add_argument('--allow-mnnvl', action="store_true", help='Allow MNNVL for communication')
    parser.add_argument('--disable-nvlink', action='store_true', help='Whether to disable NVLink for testing')
    parser.add_argument("--pressure-test", action='store_true', help='Whether to do pressure test')
    parser.add_argument("--shrink-test", action='store_true', help='Whether to simulate failure and test shrink mode')
    parser.add_argument('--use-logfmt', action='store_true', help='Whether to test LogFMT combine')
    return parser


if __name__ == '__main__':
    args = build_arg_parser().parse_args()
    # The low-latency test only runs when the world spans more processes than --num-processes.
    if args.world_size > args.num_processes:
        test_loop(args.local_rank, args.num_processes, args)
    elif args.rank == 0:
        # Say why nothing ran instead of exiting silently.
        print(f'[skip] low-latency test not run: world size {args.world_size} does not exceed '
              f'--num-processes {args.num_processes}', flush=True)
#!/bin/bash
# Per-rank entry point: builds the python command for the selected test mode
# and starts it, optionally through a NUMA-binding wrapper script.
#
# Usage:
#   test_start.sh MASTER_HOST MASTER_PORT [--launch_with_binding=PATH]
#                 [--testmode=intranode|internode|lowlatency]
#                 [--profiling=torch|hip]
#
# Rank information is taken from the OMPI_COMM_WORLD_* environment variables.
# Any --testmode value other than intranode/lowlatency selects internode.

launch_with_binding=""
testmode=""
profiling=""
for para in "$@"; do
  if [[ ${para} == --launch_with_binding* ]]; then
    launch_with_binding=${para#*=}
  elif [[ ${para} == --testmode* ]]; then
    testmode=${para#*=}
  elif [[ ${para} == --profiling* ]]; then
    profiling=${para#*=}
  fi
done

# default env
DIST_URL=${1}
DIST_PORT=${2}
# Fall back to a single-process layout when the launcher did not set these,
# so the numeric test and the python arguments below stay well-formed.
RANK=${OMPI_COMM_WORLD_RANK:-0}
LOCAL_RANK=${OMPI_COMM_WORLD_LOCAL_RANK:-0}
WORLD_SIZE=${OMPI_COMM_WORLD_SIZE:-1}
# The test scripts are addressed through TEST_DIR; default to this script's
# own directory when the caller did not export it.
TEST_DIR=${TEST_DIR:-$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)}

# =============================================================================
# Debug output (confirm that the environment variables were passed correctly)
# =============================================================================
if [[ ${RANK} -eq 0 ]]; then
  echo "=== DeepEP Test Start ==="
  echo "Test mode: ${testmode:-internode}"
  echo "World size: $WORLD_SIZE"
  echo "Master: $DIST_URL:$DIST_PORT"
  echo "PYTHONPATH: $PYTHONPATH"
  echo "TEST_DIR: $TEST_DIR"
  echo "ROCSHMEM_TOPO_FILE_FORCE: $ROCSHMEM_TOPO_FILE_FORCE"
  echo "NCCL_PLUGIN: ${NCCL_NET_PLUGIN:-none}"
  echo "NCCL_IB_HCA: ${NCCL_IB_HCA:-auto}"
  echo "HSA_FORCE_FINE_GRAIN_PCIE: ${HSA_FORCE_FINE_GRAIN_PCIE:-not set}"
fi

DISTRIBUTED_ARGS=(
  --rank "${RANK}"
  --world-size "${WORLD_SIZE}"
  --local-rank "${LOCAL_RANK}"
  --dist-url "tcp://${DIST_URL}:${DIST_PORT}"
)
TEST_BASE_ARGS=(
  --hidden 7168
  --num-experts 256
  --num-topk 8
)

# APP definition for the three modes. APP is kept as an array so that every
# argument reaches the program as exactly one word.
case "${testmode}" in
  intranode)
    # intra-node test
    INTRANODE_ARGS=(
      "${TEST_BASE_ARGS[@]}"
      # intranode-specific arguments:
      --num-tokens 4096
    )
    APP=(python3 -u "${TEST_DIR}/test_intranode.py"
      "${DISTRIBUTED_ARGS[@]}"
      "${INTRANODE_ARGS[@]}")
    ;;
  lowlatency)
    # low-latency test
    LOWLATENCY_ARGS=(
      "${TEST_BASE_ARGS[@]}"
      # lowlatency-specific arguments:
      --num-tokens 128
      # --pressure-test
    )
    APP=(python3 -u "${TEST_DIR}/test_low_latency.py"
      "${DISTRIBUTED_ARGS[@]}"
      "${LOWLATENCY_ARGS[@]}")
    ;;
  internode | *)
    # inter-node test (default)
    INTERNODE_ARGS=(
      "${TEST_BASE_ARGS[@]}"
      # internode-specific arguments:
      --num-tokens 4096
      # --test-ll-compatibility
    )
    APP=(python3 -u "${TEST_DIR}/test_internode.py"
      "${DISTRIBUTED_ARGS[@]}"
      "${INTERNODE_ARGS[@]}")
    ;;
esac

###############################################################################
# NOTE(review): the argument parsers of test_intranode.py and
# test_low_latency.py do not appear to declare these --profile* options;
# confirm the test scripts accept them before using --profiling.
TORCH_PROFILE_ARGS=(
  --profile
  --profile-ranks 0 1 2 3 4 6 8 32
  --profile-step-start 3
  --profile-step-end 4
  --profile-dir torch_prof_aibenchmark_8nodes_tp4-pp2-ep8-etp2-cp1-vp2
  --use-pytorch-profiler
)
HIP_PROFILE_ARGS=(
  --profile
  --profile-ranks 0 1 2 3 4 6 8 32
  --profile-step-start 4
  --profile-step-end 5
  --use-hip-profiler
)
if [[ ${profiling} == "torch" ]]; then
  APP+=("${TORCH_PROFILE_ARGS[@]}")
elif [[ ${profiling} == "hip" ]]; then
  mkdir -p hip_prof_data
  APP+=("${HIP_PROFILE_ARGS[@]}")
  APP=(hipprof -d hip_prof_data --hip-trace --trace-off "${APP[@]}")
fi

###############################################################################
echo "launch_with_binding=${launch_with_binding}, APP=${APP[*]}"
#for hygon cpu
if [[ -n ${launch_with_binding} ]]; then
  "${launch_with_binding}" "${LOCAL_RANK}" "${APP[@]}"
else
  # Without a wrapper, run the command directly rather than trying to
  # execute the local rank number as a program.
  "${APP[@]}"
fi
0000:9f:00.0 mlx5_2 2
0000:56:00.0 mlx5_3 3
0000:5d:00.0 mlx5_4 4
0000:05:00.0 mlx5_5 5
0000:e5:00.0 mlx5_6 6
0000:c1:00.0 mlx5_7 7
0000:ca:00.0 mlx5_8 8
0000:b1:00.0 mlx5_9 9
\ No newline at end of file
import inspect
import json
import tempfile
from pathlib import Path
import numpy as np
import os
import sys
import torch
import torch.distributed as dist
from typing import Optional, Union
def init_dist(rank: int, world_size: int, local_rank: int, dist_url: str):
    """Initialise the default NCCL process group and select this process's GPU.

    Args:
        rank: Rank passed to ``dist.init_process_group``.
        world_size: World size passed to ``dist.init_process_group``.
        local_rank: Index of the CUDA device to use (``cuda:{local_rank}``).
        dist_url: Value passed as ``init_method``.

    Returns:
        ``(dist.get_rank(), dist.get_world_size(), group)`` where ``group`` is
        a new process group covering ranks ``0 .. world_size - 1``.
    """
    # NOTES: you may rewrite this function with your own cluster settings
    init_kwargs = dict(
        backend='nccl',
        init_method=f'{dist_url}',
        world_size=world_size,
        rank=rank,
    )
    # Only pass `device_id` when this torch version's signature accepts it.
    accepted = inspect.signature(dist.init_process_group).parameters
    if 'device_id' in accepted:
        # noinspection PyTypeChecker
        init_kwargs['device_id'] = torch.device(f'cuda:{local_rank}')
    dist.init_process_group(**init_kwargs)

    # Process-wide defaults used by the tests.
    torch.set_default_dtype(torch.bfloat16)
    torch.set_default_device('cuda')
    torch.cuda.set_device(local_rank)

    actual_rank = dist.get_rank()
    actual_world_size = dist.get_world_size()
    full_group = dist.new_group(list(range(world_size)))
    return actual_rank, actual_world_size, full_group
def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Return a scalar dissimilarity between two tensors (0.0 for identical inputs).

    Both tensors are converted to float64 and shifted by +1, then the result is
    ``1 - 2 * sum(x * y) / sum(x * x + y * y)`` as a Python float.
    """
    lhs = x.double() + 1
    rhs = y.double() + 1
    energy = (lhs * lhs + rhs * rhs).sum()
    cross = (lhs * rhs).sum()
    similarity = 2 * cross / energy
    return (1 - similarity).item()
def align_up(x, y):
    """Round ``x`` up to a multiple of ``y`` (callers here pass positive integers)."""
    num_chunks = (x + y - 1) // y
    return num_chunks * y
def per_token_cast_to_fp8(x: torch.Tensor):
    """Quantise a 2-D tensor to ``float8_e4m3fn`` with one scale per 128-column group.

    The columns are zero-padded to a multiple of 128, each row is split into
    groups of 128 values, and every group is scaled by ``448.0 / amax`` before
    the cast (``amax`` is clamped to at least 1e-4).
    NOTE(review): 448.0 is presumably the largest finite float8_e4m3fn value — confirm.

    Returns:
        ``(x_fp8, scales)``: ``x_fp8`` has the shape of ``x``; ``scales`` has
        shape ``(rows, number_of_groups)`` and holds ``amax / 448.0``.
    """
    assert x.dim() == 2
    num_rows, num_cols = x.shape
    padded_cols = align_up(num_cols, 128)
    padded = torch.nn.functional.pad(x, (0, padded_cols - num_cols), mode='constant', value=0)
    groups = padded.view(num_rows, -1, 128)
    group_amax = groups.abs().float().amax(dim=2).view(num_rows, -1).clamp(1e-4)
    quantised = (groups * (448.0 / group_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
    # Drop the padding columns again before returning.
    quantised = quantised.view(num_rows, padded_cols)[:, :num_cols].contiguous()
    scales = (group_amax / 448.0).view(num_rows, -1)
    return quantised, scales
def per_token_cast_pg_back(x: torch.Tensor, x_scales: torch.Tensor):
    """Dequantise ``x`` to bfloat16 using one scale per group of 128 columns.

    An empty ``x`` is only converted to bfloat16. Otherwise the columns are
    zero-padded to a multiple of 128, each group of 128 values is multiplied by
    its scale, and the padding is removed again.

    When ``x_scales`` has dtype ``torch.int`` its bytes are shifted left by 23
    bits and reinterpreted as float32.
    NOTE(review): this looks like decoding power-of-two exponent bytes — confirm.
    """
    if x.numel() == 0:
        return x.to(torch.bfloat16)
    assert x.dim() == 2
    num_rows, num_cols = x.shape
    padded_cols = align_up(num_cols, 128)
    padded = torch.nn.functional.pad(x, (0, padded_cols - num_cols), mode='constant', value=0)
    if x_scales.dtype == torch.int:
        scale_bytes = x_scales.view(dtype=torch.uint8)
        x_scales = (scale_bytes.to(torch.int) << 23).view(dtype=torch.float)
    grouped = padded.to(torch.float32).view(x.size(0), -1, 128)
    group_scales = x_scales.view(x.size(0), -1, 1)
    restored = (grouped * group_scales).view(padded.shape)
    return restored.to(torch.bfloat16)[:, :num_cols].contiguous()
def per_token_cast_pc_back(x: torch.Tensor, x_scales: torch.Tensor):
    """Dequantise ``x`` to bfloat16 by multiplying it element-wise with ``x_scales``.

    An empty ``x`` is only converted to bfloat16. Otherwise the columns are
    zero-padded to a multiple of 128, both operands are viewed as
    ``(rows, -1, 1)`` in float32 and multiplied, and the padding is removed.
    NOTE(review): "pc" presumably means per-channel scales — confirm with callers.
    """
    if x.numel() == 0:
        return x.to(torch.bfloat16)
    assert x.dim() == 2
    num_rows, num_cols = x.shape
    padded_cols = align_up(num_cols, 128)
    padded = torch.nn.functional.pad(x, (0, padded_cols - num_cols), mode='constant', value=0)
    values = padded.to(torch.float32).view(num_rows, -1, 1)
    scales = x_scales.view(num_rows, -1, 1).to(torch.float32)
    dequantised = (values * scales).view(num_rows, padded_cols)
    return dequantised[:, :num_cols].to(torch.bfloat16).contiguous()
def inplace_unique(x: torch.Tensor, num_slots: int):
    """Replace each row of ``x`` in place with its distinct non-negative values.

    Negative entries are ignored. After the call each row holds its distinct
    values in descending order, followed by ``-1`` padding. Values are expected
    to lie in ``[0, num_slots)``. Nothing is returned.
    """
    assert x.dim() == 2
    # Route negative entries to an extra bucket that is discarded below.
    shifted = x.masked_fill(x < 0, num_slots)
    counts = torch.zeros((x.size(0), num_slots + 1), dtype=x.dtype, device=x.device)
    counts.scatter_add_(1, shifted, torch.ones_like(shifted))
    counts = counts[:, :num_slots]
    # Slots that never occurred become -1; a descending sort moves them to the end.
    ordered_counts, ordered_slots = torch.sort(counts, dim=-1, descending=True)
    ordered_slots.masked_fill_(ordered_counts == 0, -1)
    ordered_slots = torch.sort(ordered_slots, descending=True, dim=-1).values
    x.fill_(-1)
    keep = min(num_slots, x.size(1))
    x[:, :keep] = ordered_slots[:, :keep]
def create_grouped_scores(scores: torch.Tensor, group_idx: torch.Tensor, num_groups: int):
    """Zero every score whose expert group is not listed in ``group_idx``.

    ``scores`` is ``(num_tokens, num_experts)``; its columns are split into
    ``num_groups`` equal contiguous groups. ``group_idx`` holds, per token, the
    indices of the groups to keep. The result has the shape of ``scores``.
    """
    num_tokens, num_experts = scores.shape
    grouped = scores.view(num_tokens, num_groups, -1)
    keep = torch.zeros((num_tokens, num_groups), dtype=torch.bool, device=grouped.device)
    keep = keep.scatter_(1, group_idx, True)
    keep = keep.unsqueeze(-1).expand_as(grouped)
    return (grouped * keep).view(num_tokens, num_experts)
def bench(fn, num_warmups: int = 50, num_tests: int = 50, post_fn=None):
    """Time ``fn`` on the GPU with CUDA events.

    Args:
        fn: Callable to benchmark; called without arguments.
        num_warmups: Number of untimed calls made first.
        num_tests: Number of timed calls.
        post_fn: Optional callable run after each timed call, outside the
            timed region.

    Returns:
        ``(average, minimum, maximum)`` of the per-call times, each event time
        divided by 1e3. The first timed call is excluded.
    """
    # Flush L2 cache with 256 MB data
    torch.cuda.synchronize()
    l2_scratch = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')

    # Warmup
    for _ in range(num_warmups):
        fn()

    # Flush L2
    l2_scratch.zero_()

    # Testing
    starts = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
    stops = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
    for begin, finish in zip(starts, stops):
        # Record
        begin.record()
        fn()
        finish.record()
        if post_fn is not None:
            post_fn()
    torch.cuda.synchronize()

    per_call = [begin.elapsed_time(finish) / 1e3 for begin, finish in zip(starts, stops)]
    # Drop the first measurement.
    timings = np.array(per_call)[1:]
    return np.average(timings), np.min(timings), np.max(timings)
class empty_suppress:
    """Context manager that does nothing; the counterpart of ``suppress_stdout_stderr``."""

    def __enter__(self):
        return self

    def __exit__(self, *_):
        # Returning None lets any exception propagate.
        return None
class suppress_stdout_stderr:
    """Context manager that discards everything written to stdout and stderr.

    Both the Python-level ``sys.stdout`` / ``sys.stderr`` objects and the file
    descriptors behind them are pointed at ``os.devnull``, so writes that go
    straight to those descriptors are dropped as well. Everything is restored
    on exit.
    """

    def __enter__(self):
        self.outnull_file = open(os.devnull, 'w')
        self.errnull_file = open(os.devnull, 'w')

        # Remember the descriptors in use and keep duplicates so they can be restored.
        self.old_stdout_fileno_undup = sys.stdout.fileno()
        self.old_stderr_fileno_undup = sys.stderr.fileno()
        self.old_stdout_fileno = os.dup(self.old_stdout_fileno_undup)
        self.old_stderr_fileno = os.dup(self.old_stderr_fileno_undup)

        self.old_stdout, self.old_stderr = sys.stdout, sys.stderr

        # Redirect at descriptor level first, then at Python level.
        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
        sys.stdout, sys.stderr = self.outnull_file, self.errnull_file
        return self

    def __exit__(self, *_):
        sys.stdout, sys.stderr = self.old_stdout, self.old_stderr

        saved_and_target = (
            (self.old_stdout_fileno, self.old_stdout_fileno_undup),
            (self.old_stderr_fileno, self.old_stderr_fileno_undup),
        )
        for saved_fd, target_fd in saved_and_target:
            os.dup2(saved_fd, target_fd)
        for saved_fd, _target_fd in saved_and_target:
            os.close(saved_fd)

        self.outnull_file.close()
        self.errnull_file.close()
def bench_kineto(fn, kernel_names: Union[str, tuple], num_tests: int = 30, suppress_kineto_output: bool = False,
                 trace_path: Optional[str] = None, barrier_comm_profiling: bool = False,
                 num_kernels_per_period: int = 1):
    """Measure GPU kernel durations of ``fn`` with the PyTorch profiler.

    Args:
        fn: Callable to profile; called ``num_tests`` times per profiler step.
        kernel_names: A kernel name, or a tuple of names, to look up in the
            profiler table. Each name must match exactly one table line.
        num_tests: Number of ``fn`` calls per profiler step.
        suppress_kineto_output: If true, stdout/stderr are silenced while
            profiling.
        trace_path: If given, a Chrome trace is exported to this path.
        barrier_comm_profiling: If true, a large matmul and an all-reduce are
            issued before the ``fn`` calls of each step.
        num_kernels_per_period: If greater than 1, each kernel's entry becomes
            a list with one average duration per position within a period.

    Returns:
        Durations converted with the ``units`` table below (ms / 1e3,
        us / 1e6): a list with one entry per name when ``kernel_names`` is a
        tuple, otherwise the single entry for the given name.

    NOTE(review): the nesting of this function was reconstructed from source
    whose indentation was lost — confirm against the original file.
    """
    # Profile
    suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
    with suppress():
        # Two steps are run below with wait=1 and active=1; presumably only the second is recorded.
        schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1)
        with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) as prof:
            for i in range(2):
                # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
                if barrier_comm_profiling:
                    lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
                    rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
                    lhs @ rhs
                    dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda'))
                for _ in range(num_tests):
                    fn()
                torch.cuda.synchronize()
                prof.step()

    # Parse the profiling table
    assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
    is_tuple = isinstance(kernel_names, tuple)
    prof_lines = prof.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
    kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
    assert all([isinstance(name, str) for name in kernel_names])
    # Every requested name must appear on exactly one line of the table.
    for name in kernel_names:
        assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table'

    # Save chrome traces
    if trace_path is not None:
        prof.export_chrome_trace(trace_path)

    # Return average kernel durations
    units = {'ms': 1e3, 'us': 1e6}
    kernel_durations = []
    for name in kernel_names:
        for line in prof_lines:
            if name in line:
                # The second-to-last whitespace-separated field holds the time with its unit suffix.
                time_str = line.split()[-2]
                # A value whose suffix is neither 'ms' nor 'us' appends nothing.
                for unit, scale in units.items():
                    if unit in time_str:
                        kernel_durations.append(float(time_str.replace(unit, '')) / scale)
                        break
                break

    # Expand the kernels by periods
    if num_kernels_per_period > 1:
        # Round-trip through a temporary Chrome trace to get per-event data.
        with tempfile.NamedTemporaryFile(suffix='.json') as tmp:
            prof.export_chrome_trace(tmp.name)
            profile_data = json.loads(Path(tmp.name).read_text())

        for i, kernel_name in enumerate(kernel_names):
            events = [event for event in profile_data['traceEvents'] if f'::{kernel_name}' in event['name']]
            events = sorted(events, key=lambda event: event['ts'])
            # 'dur' is divided by 1e6 — presumably microseconds to seconds.
            durations = [event['dur'] / 1e6 for event in events]
            assert len(durations) % num_kernels_per_period == 0
            num_kernel_patterns = len(durations) // num_kernels_per_period
            # Average the j-th event of every period separately.
            kernel_durations[i] = [sum(durations[j::num_kernels_per_period]) / num_kernel_patterns
                                   for j in range(num_kernels_per_period)]

    # Return execution durations
    return kernel_durations if is_tuple else kernel_durations[0]
def hash_tensor(t: torch.Tensor):
    """Return a checksum: the tensor's data reinterpreted as 32-bit integers and summed."""
    as_int32 = t.view(torch.int32)
    return as_int32.sum().item()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment