Commit 7705f533 authored by Chenggang Zhao's avatar Chenggang Zhao
Browse files

Refactor testing arguments

parent 6b17f4fa
import argparse
import os import os
import time import time
import torch import torch
...@@ -11,13 +12,13 @@ from utils import init_dist, bench, calc_diff, create_grouped_scores, inplace_un ...@@ -11,13 +12,13 @@ from utils import init_dist, bench, calc_diff, create_grouped_scores, inplace_un
import test_low_latency import test_low_latency
def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup, args): # noinspection PyShadowingNames
def test_main(args: argparse.Namespace, num_sms: int,
local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int,
buffer: deep_ep.Buffer, group: dist.ProcessGroup):
# Settings # Settings
num_tokens = args.num_tokens num_tokens, hidden = args.num_tokens, args.hidden
hidden = args.hidden num_topk_groups, num_topk, num_experts = args.num_topk_groups, args.num_topk, args.num_experts
num_topk_groups = args.num_topk_groups
num_topk = args.num_topk
num_experts = args.num_experts
assert num_experts % num_ranks == 0 and num_local_ranks == 8 assert num_experts % num_ranks == 0 and num_local_ranks == 8
if local_rank == 0: if local_rank == 0:
...@@ -223,29 +224,28 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in ...@@ -223,29 +224,28 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
print('', flush=True) print('', flush=True)
# noinspection PyUnboundLocalVariable # noinspection PyUnboundLocalVariable,PyShadowingNames
def test_loop(local_rank: int, num_local_ranks: int, args): def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_nodes = int(os.getenv('WORLD_SIZE', 1)) num_nodes = int(os.getenv('WORLD_SIZE', 1))
rank, num_ranks, group = init_dist(local_rank, num_local_ranks) rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
test_ll_compatibility = os.getenv('EP_TEST_LL_COMPATIBILITY', False) if args.test_ll_compatibility:
if test_ll_compatibility:
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9 ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
num_sms = 24 num_sms = 24
num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if test_ll_compatibility else 0) num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if args.test_ll_compatibility else 0)
buffer = deep_ep.Buffer(group, int(1e9), int(1e9), low_latency_mode=test_ll_compatibility, buffer = deep_ep.Buffer(group, int(1e9), int(1e9), low_latency_mode=args.test_ll_compatibility,
num_qps_per_rank=num_qps_per_rank) num_qps_per_rank=num_qps_per_rank)
assert num_local_ranks == 8 and num_ranks > 8 assert num_local_ranks == 8 and num_ranks > 8
torch.manual_seed(rank) torch.manual_seed(rank)
for i in (num_sms, ): for i in (num_sms, ):
test_main(i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group, args) test_main(args, i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group)
if local_rank == 0: if local_rank == 0:
print('', flush=True) print('', flush=True)
# Test compatibility with low latency functions # Test compatibility with low latency functions
if test_ll_compatibility: if args.test_ll_compatibility:
buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts) buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1) test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)
...@@ -255,8 +255,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args): ...@@ -255,8 +255,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args):
if __name__ == '__main__': if __name__ == '__main__':
import argparse parser = argparse.ArgumentParser(description='Test internode EP kernels')
parser = argparse.ArgumentParser(description='Test internode expert parallel')
parser.add_argument('--num-processes', type=int, default=8, parser.add_argument('--num-processes', type=int, default=8,
help='Number of processes to spawn (default: 8)') help='Number of processes to spawn (default: 8)')
parser.add_argument('--num-tokens', type=int, default=4096, parser.add_argument('--num-tokens', type=int, default=4096,
...@@ -264,21 +263,19 @@ if __name__ == '__main__': ...@@ -264,21 +263,19 @@ if __name__ == '__main__':
parser.add_argument('--hidden', type=int, default=7168, parser.add_argument('--hidden', type=int, default=7168,
help='Hidden dimension size (default: 7168)') help='Hidden dimension size (default: 7168)')
parser.add_argument('--num-topk-groups', type=int, default=None, parser.add_argument('--num-topk-groups', type=int, default=None,
help='Number of top-k groups (default: min(num_nodes, 4))') help='Number of top-k groups (default: `min(num_nodes, 4)`)')
parser.add_argument('--num-topk', type=int, default=8, parser.add_argument('--num-topk', type=int, default=8,
help='Number of top-k experts (default: 8)') help='Number of top-k experts (default: 8)')
parser.add_argument('--num-experts', type=int, default=None, parser.add_argument('--num-experts', type=int, default=256,
help='Number of experts (default: calculated as (256 // num_ranks) * num_ranks)') help='Number of experts (default: 256')
parser.add_argument('--test-ll-compatibility', action='store_true',
help='whether to test compatibility with low-latency kernels')
args = parser.parse_args() args = parser.parse_args()
# Set default num_topk_groups if not provided # Set default `num_topk_groups` if not provided
if args.num_topk_groups is None: if args.num_topk_groups is None:
num_nodes = int(os.getenv('WORLD_SIZE', 1)) num_nodes = int(os.getenv('WORLD_SIZE', 1))
args.num_topk_groups = min(num_nodes, 4) args.num_topk_groups = min(num_nodes, 4)
# Set default num_experts if not provided
if args.num_experts is None:
args.num_experts = (256 // args.num_processes) * args.num_processes
num_processes = args.num_processes num_processes = args.num_processes
torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes) torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)
import os import argparse
import time import time
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -11,12 +11,12 @@ from utils import init_dist, bench, calc_diff, inplace_unique, per_token_cast_to ...@@ -11,12 +11,12 @@ from utils import init_dist, bench, calc_diff, inplace_unique, per_token_cast_to
import test_low_latency import test_low_latency
def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup, args): # noinspection PyShadowingNames
def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks: int, rank: int,
buffer: deep_ep.Buffer, group: dist.ProcessGroup):
# Settings # Settings
num_tokens = args.num_tokens num_tokens, hidden = args.num_tokens, args.hidden
hidden = args.hidden num_topk, num_experts = args.num_topk, args.num_experts
num_topk = args.num_topk
num_experts = args.num_experts
assert num_experts % num_ranks == 0 assert num_experts % num_ranks == 0
if local_rank == 0: if local_rank == 0:
...@@ -229,8 +229,8 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: ...@@ -229,8 +229,8 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
print('', flush=True) print('', flush=True)
# noinspection PyUnboundLocalVariable # noinspection PyUnboundLocalVariable,PyShadowingNames
def test_loop(local_rank: int, num_local_ranks: int, args): def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
rank, num_ranks, group = init_dist(local_rank, num_local_ranks) rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
test_ll_compatibility, num_rdma_bytes = False, 0 test_ll_compatibility, num_rdma_bytes = False, 0
if test_ll_compatibility: if test_ll_compatibility:
...@@ -242,7 +242,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args): ...@@ -242,7 +242,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args):
torch.manual_seed(rank) torch.manual_seed(rank)
for i in (24, ): for i in (24, ):
test_main(i, local_rank, num_ranks, rank, buffer, group, args) test_main(args, i, local_rank, num_ranks, rank, buffer, group)
if local_rank == 0: if local_rank == 0:
print('', flush=True) print('', flush=True)
...@@ -257,8 +257,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args): ...@@ -257,8 +257,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args):
if __name__ == '__main__': if __name__ == '__main__':
import argparse parser = argparse.ArgumentParser(description='Test intranode EP kernels')
parser = argparse.ArgumentParser(description='Test intranode expert parallel')
parser.add_argument('--num-processes', type=int, default=8, parser.add_argument('--num-processes', type=int, default=8,
help='Number of processes to spawn (default: 8)') help='Number of processes to spawn (default: 8)')
parser.add_argument('--num-tokens', type=int, default=4096, parser.add_argument('--num-tokens', type=int, default=4096,
...@@ -267,13 +266,9 @@ if __name__ == '__main__': ...@@ -267,13 +266,9 @@ if __name__ == '__main__':
help='Hidden dimension size (default: 7168)') help='Hidden dimension size (default: 7168)')
parser.add_argument('--num-topk', type=int, default=8, parser.add_argument('--num-topk', type=int, default=8,
help='Number of top-k experts (default: 8)') help='Number of top-k experts (default: 8)')
parser.add_argument('--num-experts', type=int, default=None, parser.add_argument('--num-experts', type=int, default=256,
help='Number of experts (default: calculated as (256 // num_ranks) * num_ranks)') help='Number of experts (default: 256)')
args = parser.parse_args() args = parser.parse_args()
# Set default num_experts if not provided
if args.num_experts is None:
args.num_experts = (256 // args.num_processes) * args.num_processes
num_processes = args.num_processes num_processes = args.num_processes
torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes) torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)
import os import argparse
import random import random
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -16,7 +16,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -16,7 +16,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
assert num_experts % num_ranks == 0 assert num_experts % num_ranks == 0
num_local_experts = num_experts // num_ranks num_local_experts = num_experts // num_ranks
# NOTES: the integers greater than 256 exceeds the BF16 precision limit # NOTES: the integers greater than 256 exceed the BF16 precision limit
rank_offset = 128 rank_offset = 128
assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)' assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)'
...@@ -98,16 +98,6 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -98,16 +98,6 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
assert diff < (7e-4 if round_scale else 1e-5), f'Error: {diff=}, {zero_copy=}' assert diff < (7e-4 if round_scale else 1e-5), f'Error: {diff=}, {zero_copy=}'
hash_value ^= hash_tensor(combined_x) hash_value ^= hash_tensor(combined_x)
def create_test_cast_with_outliers(num_outliers):
tmp = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
tmp /= tmp.abs().amax(dim=1).view(-1, 1)
assert tmp.abs().amax().item() <= 1
# Create some amax outliers
for i in range(num_outliers):
tmp[random.randint(0, num_tokens - 1)] *= 1e3
return tmp
# noinspection PyShadowingNames # noinspection PyShadowingNames
def large_gemm_with_hook(hook): def large_gemm_with_hook(hook):
mat_0 = torch.randn((8192, 8192), dtype=torch.float) mat_0 = torch.randn((8192, 8192), dtype=torch.float)
...@@ -156,13 +146,11 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -156,13 +146,11 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
return hash_value return hash_value
# noinspection PyUnboundLocalVariable # noinspection PyUnboundLocalVariable,PyShadowingNames
def test_loop(local_rank: int, num_local_ranks: int, args): def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
rank, num_ranks, group = init_dist(local_rank, num_local_ranks) rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
num_tokens = args.num_tokens num_tokens, hidden = args.num_tokens, args.hidden
hidden = args.hidden num_topk, num_experts = args.num_topk, args.num_experts
num_topk = args.num_topk
num_experts = args.num_experts
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts) num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
if local_rank == 0: if local_rank == 0:
...@@ -186,8 +174,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args): ...@@ -186,8 +174,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args):
if __name__ == '__main__': if __name__ == '__main__':
# TODO: you may modify NUMA binding for less CPU overhead # TODO: you may modify NUMA binding for less CPU overhead
import argparse parser = argparse.ArgumentParser(description='Test low-latency EP kernels')
parser = argparse.ArgumentParser(description='Test low latency expert parallel')
parser.add_argument('--num-processes', type=int, default=8, parser.add_argument('--num-processes', type=int, default=8,
help='Number of processes to spawn (default: 8)') help='Number of processes to spawn (default: 8)')
parser.add_argument('--num-tokens', type=int, default=128, parser.add_argument('--num-tokens', type=int, default=128,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment