Commit ef70b83e (unverified), authored by Shangyan Zhou, committed by GitHub
Browse files

Add pressure test mode for internode test (#400)

* Suppress kineto output

* Add pressure test mode

* Add `x_pure_rand_e4m3` test

* Add more results into hash value
parent 3e2c5d80
...@@ -6,7 +6,7 @@ import torch.distributed as dist ...@@ -6,7 +6,7 @@ import torch.distributed as dist
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
import deep_ep import deep_ep
from utils import init_dist, bench, bench_kineto, calc_diff, create_grouped_scores, inplace_unique, per_token_cast_to_fp8, per_token_cast_back from utils import init_dist, bench, bench_kineto, calc_diff, create_grouped_scores, inplace_unique, per_token_cast_to_fp8, per_token_cast_back, hash_tensor
# Test compatibility with low latency functions # Test compatibility with low latency functions
import test_low_latency import test_low_latency
...@@ -15,7 +15,7 @@ import test_low_latency ...@@ -15,7 +15,7 @@ import test_low_latency
# noinspection PyShadowingNames # noinspection PyShadowingNames
def test_main(args: argparse.Namespace, num_sms: int, def test_main(args: argparse.Namespace, num_sms: int,
local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int, local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int,
buffer: deep_ep.Buffer, group: dist.ProcessGroup): buffer: deep_ep.Buffer, group: dist.ProcessGroup, skip_benchmark: bool = False):
# Settings # Settings
num_tokens, hidden = args.num_tokens, args.hidden num_tokens, hidden = args.num_tokens, args.hidden
num_topk_groups, num_topk, num_experts = args.num_topk_groups, args.num_topk, args.num_experts num_topk_groups, num_topk, num_experts = args.num_topk_groups, args.num_topk, args.num_experts
...@@ -28,6 +28,7 @@ def test_main(args: argparse.Namespace, num_sms: int, ...@@ -28,6 +28,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
x_e4m3 = per_token_cast_to_fp8(x) x_e4m3 = per_token_cast_to_fp8(x)
x_pure_rand_e4m3 = per_token_cast_to_fp8(x_pure_rand)
x_e4m3 = (x_e4m3[0], x_e4m3[1].T.contiguous().T) x_e4m3 = (x_e4m3[0], x_e4m3[1].T.contiguous().T)
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1 scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1) group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1)
...@@ -42,6 +43,7 @@ def test_main(args: argparse.Namespace, num_sms: int, ...@@ -42,6 +43,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
rdma_rank_idx = rank_idx // num_local_ranks rdma_rank_idx = rank_idx // num_local_ranks
rdma_rank_idx.masked_fill_(rank_idx == -1, -1) rdma_rank_idx.masked_fill_(rank_idx == -1, -1)
inplace_unique(rdma_rank_idx, num_nodes) inplace_unique(rdma_rank_idx, num_nodes)
hash_value = 0
# RDMA dispatch counts # RDMA dispatch counts
rdma_idx = topk_idx // (num_experts // num_nodes) rdma_idx = topk_idx // (num_experts // num_nodes)
...@@ -103,25 +105,33 @@ def test_main(args: argparse.Namespace, num_sms: int, ...@@ -103,25 +105,33 @@ def test_main(args: argparse.Namespace, num_sms: int,
for previous_mode in (False, True): for previous_mode in (False, True):
for async_mode in (False, True): for async_mode in (False, True):
for current_x in (x_pure_rand, x, x_e4m3): for current_x in (x_pure_rand, x, x_pure_rand_e4m3, x_e4m3):
for with_topk in (False, True): for with_topk in (False, True):
is_rand = current_x is x_pure_rand or current_x is x_pure_rand_e4m3
if local_rank == 0: if local_rank == 0:
print(f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', flush=True, end='') print(f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', flush=True, end='')
dispatch_args = {'x': current_x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank, 'is_token_in_rank': is_token_in_rank, dispatch_args = {'x': current_x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank, 'is_token_in_rank': is_token_in_rank,
'num_tokens_per_expert': num_tokens_per_expert, 'config': config, 'async_finish': async_mode} 'num_tokens_per_expert': num_tokens_per_expert, 'config': config, 'async_finish': async_mode}
if with_topk: if with_topk:
dispatch_args.update({'topk_idx': topk_idx, 'topk_weights': topk_weights_pure_rand if current_x is x_pure_rand else topk_weights}) dispatch_args.update({'topk_idx': topk_idx, 'topk_weights': topk_weights_pure_rand if is_rand else topk_weights})
if previous_mode: if previous_mode:
dispatch_args.update({'previous_event': buffer.capture()}) dispatch_args.update({'previous_event': buffer.capture()})
recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch(**dispatch_args) recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else () event.current_stream_wait() if async_mode else ()
if current_x is x_pure_rand or current_x is x:
hash_value += hash_tensor(recv_x)
else:
hash_value += hash_tensor(recv_x[0])
hash_value += hash_tensor(recv_x[1])
recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
# Checks # Checks
recv_gbl_rank_prefix_sum = handle[-4] recv_gbl_rank_prefix_sum = handle[-4]
assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(0), f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}' assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(0), f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}'
assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list
if current_x is not x_pure_rand: if not is_rand:
check_data(recv_x, recv_gbl_rank_prefix_sum) check_data(recv_x, recv_gbl_rank_prefix_sum)
if with_topk: if with_topk:
# Check `topk_idx` # Check `topk_idx`
...@@ -130,7 +140,7 @@ def test_main(args: argparse.Namespace, num_sms: int, ...@@ -130,7 +140,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
assert recv_topk_idx.eq(i).sum().item() == count assert recv_topk_idx.eq(i).sum().item() == count
# Check `topk_weights` # Check `topk_weights`
if current_x is not x_pure_rand: if not is_rand:
recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)] recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)]
check_data(recv_topk_weights, recv_gbl_rank_prefix_sum) check_data(recv_topk_weights, recv_gbl_rank_prefix_sum)
...@@ -142,7 +152,7 @@ def test_main(args: argparse.Namespace, num_sms: int, ...@@ -142,7 +152,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args) recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else () event.current_stream_wait() if async_mode else ()
recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
if current_x is not x_pure_rand: if not is_rand:
check_data(recv_x, recv_gbl_rank_prefix_sum) check_data(recv_x, recv_gbl_rank_prefix_sum)
# Test combine # Test combine
...@@ -156,13 +166,15 @@ def test_main(args: argparse.Namespace, num_sms: int, ...@@ -156,13 +166,15 @@ def test_main(args: argparse.Namespace, num_sms: int,
combined_x, combined_topk_weights, event = buffer.combine(**combine_args) combined_x, combined_topk_weights, event = buffer.combine(**combine_args)
event.current_stream_wait() if async_mode else () event.current_stream_wait() if async_mode else ()
check_x = (combined_x.float() - bias_0.float() - bias_1.float()) / is_token_in_rank.sum(dim=1).unsqueeze(1) check_x = (combined_x.float() - bias_0.float() - bias_1.float()) / is_token_in_rank.sum(dim=1).unsqueeze(1)
ref_x = x_pure_rand if current_x is x_pure_rand else x ref_x = x_pure_rand if is_rand else x
assert calc_diff(check_x, ref_x) < 5e-6 assert calc_diff(check_x, ref_x) < 5e-4 if current_x is x_pure_rand_e4m3 else 5e-6
if with_topk: if with_topk:
check_topk_weights = combined_topk_weights if (current_x is x_pure_rand) else (combined_topk_weights / is_token_in_rank.sum(dim=1).unsqueeze(1)) check_topk_weights = combined_topk_weights if is_rand else (combined_topk_weights / is_token_in_rank.sum(dim=1).unsqueeze(1))
ref_topk_weights = topk_weights_pure_rand if current_x is x_pure_rand else topk_weights ref_topk_weights = topk_weights_pure_rand if is_rand else topk_weights
assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9 assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9
hash_value += hash_tensor(recv_x)
# For later tuning # For later tuning
dispatch_bf16_rdma_send_bytes = num_rdma_token_sent * hidden * 2 dispatch_bf16_rdma_send_bytes = num_rdma_token_sent * hidden * 2
dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2 dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2
...@@ -174,6 +186,9 @@ def test_main(args: argparse.Namespace, num_sms: int, ...@@ -174,6 +186,9 @@ def test_main(args: argparse.Namespace, num_sms: int,
if local_rank == 0: if local_rank == 0:
print('', flush=True) print('', flush=True)
if skip_benchmark:
return hash_value
# Tune dispatch performance # Tune dispatch performance
best_dispatch_results = None best_dispatch_results = None
fp8_factor = (1 + 4 / 128) / 2 fp8_factor = (1 + 4 / 128) / 2
...@@ -185,7 +200,7 @@ def test_main(args: argparse.Namespace, num_sms: int, ...@@ -185,7 +200,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
for rdma_chunk_size in range(4, 33, 4): for rdma_chunk_size in range(4, 33, 4):
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size) config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
tune_args = {'x': current_x, 'handle': handle, 'config': config} tune_args = {'x': current_x, 'handle': handle, 'config': config}
t, notify_t = bench_kineto(lambda: buffer.dispatch(**tune_args), ('dispatch', 'notify')) t, notify_t = bench_kineto(lambda: buffer.dispatch(**tune_args), ('dispatch', 'notify'), suppress_kineto_output=True)
if t < best_time: if t < best_time:
best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size, notify_t) best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size, notify_t)
if local_rank == 0: if local_rank == 0:
...@@ -213,7 +228,7 @@ def test_main(args: argparse.Namespace, num_sms: int, ...@@ -213,7 +228,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
for rdma_chunk_size in range(12 if num_nodes == 2 else 8, 33, 4): for rdma_chunk_size in range(12 if num_nodes == 2 else 8, 33, 4):
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size) config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
tune_args = {'x': recv_x, 'handle': handle, 'config': config} tune_args = {'x': recv_x, 'handle': handle, 'config': config}
t, notify_t = bench_kineto(lambda: buffer.combine(**tune_args), ('combine', 'notify')) t, notify_t = bench_kineto(lambda: buffer.combine(**tune_args), ('combine', 'notify'), suppress_kineto_output=True)
if local_rank == 0: if local_rank == 0:
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}, transmit: {t * 1e6:.2f} us, notify: {notify_t * 1e6:.2f} us, BW: {combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True) print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}, transmit: {t * 1e6:.2f} us, notify: {notify_t * 1e6:.2f} us, BW: {combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
if t < best_time: if t < best_time:
...@@ -222,6 +237,7 @@ def test_main(args: argparse.Namespace, num_sms: int, ...@@ -222,6 +237,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
if local_rank == 0: if local_rank == 0:
print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}, transmit: {best_time * 1e6:.2f} us, notify: {best_results[3] * 1e6:.2f} us, BW: {combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True) print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}, transmit: {best_time * 1e6:.2f} us, notify: {best_results[3] * 1e6:.2f} us, BW: {combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True)
print('', flush=True) print('', flush=True)
return hash_value
# noinspection PyUnboundLocalVariable,PyShadowingNames # noinspection PyUnboundLocalVariable,PyShadowingNames
...@@ -237,13 +253,32 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -237,13 +253,32 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
buffer = deep_ep.Buffer(group, int(2e9), int(1e9), low_latency_mode=args.test_ll_compatibility, buffer = deep_ep.Buffer(group, int(2e9), int(1e9), low_latency_mode=args.test_ll_compatibility,
num_qps_per_rank=num_qps_per_rank, explicitly_destroy=True) num_qps_per_rank=num_qps_per_rank, explicitly_destroy=True)
assert num_local_ranks == 8 and num_ranks > 8 assert num_local_ranks == 8 and num_ranks > 8
torch.manual_seed(rank)
for i in (num_sms, ): for seed in range(int(1e9)):
test_main(args, i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group)
if local_rank == 0: if local_rank == 0:
print(f'Testing with seed {seed} ...', flush=True)
torch.manual_seed(rank + seed)
ref_hash = 0
for i in (num_sms, ):
ref_hash += test_main(args, i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group, args.pressure_test_mode == 1)
if local_rank == 0:
print('', flush=True)
if args.pressure_test_mode == 0:
break
if local_rank == 0:
print(f'{ref_hash=}')
print('', flush=True) print('', flush=True)
for j in range(20):
torch.manual_seed(rank + seed)
current_hash = 0
for i in (num_sms, ):
current_hash += test_main(args, i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group, args.pressure_test_mode == 1)
if local_rank == 0:
print('', flush=True)
assert current_hash == ref_hash
# Test compatibility with low latency functions # Test compatibility with low latency functions
if args.test_ll_compatibility: if args.test_ll_compatibility:
buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts) buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
...@@ -267,6 +302,8 @@ if __name__ == '__main__': ...@@ -267,6 +302,8 @@ if __name__ == '__main__':
help='Number of top-k groups (default: `min(num_nodes, 4)`)') help='Number of top-k groups (default: `min(num_nodes, 4)`)')
parser.add_argument('--num-topk', type=int, default=8, parser.add_argument('--num-topk', type=int, default=8,
help='Number of top-k experts (default: 8)') help='Number of top-k experts (default: 8)')
parser.add_argument('--pressure-test-mode', type=int, default=0,
help='Pressure test mode. 0: don\'t do pressure test, 1: do pressure test without benchmarks, 2: do pressure test with benchmarks')
parser.add_argument('--num-experts', type=int, default=256, parser.add_argument('--num-experts', type=int, default=256,
help='Number of experts (default: 256') help='Number of experts (default: 256')
parser.add_argument('--test-ll-compatibility', action='store_true', parser.add_argument('--test-ll-compatibility', action='store_true',
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment