Commit 7705f533 authored by Chenggang Zhao
Browse files

Refactor testing arguments

parent 6b17f4fa
import argparse
import os
import time
import torch
......@@ -11,13 +12,13 @@ from utils import init_dist, bench, calc_diff, create_grouped_scores, inplace_un
import test_low_latency
def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup, args):
# noinspection PyShadowingNames
def test_main(args: argparse.Namespace, num_sms: int,
local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int,
buffer: deep_ep.Buffer, group: dist.ProcessGroup):
# Settings
num_tokens = args.num_tokens
hidden = args.hidden
num_topk_groups = args.num_topk_groups
num_topk = args.num_topk
num_experts = args.num_experts
num_tokens, hidden = args.num_tokens, args.hidden
num_topk_groups, num_topk, num_experts = args.num_topk_groups, args.num_topk, args.num_experts
assert num_experts % num_ranks == 0 and num_local_ranks == 8
if local_rank == 0:
......@@ -223,29 +224,28 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
print('', flush=True)
# noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int, args):
# noinspection PyUnboundLocalVariable,PyShadowingNames
def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_nodes = int(os.getenv('WORLD_SIZE', 1))
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
test_ll_compatibility = os.getenv('EP_TEST_LL_COMPATIBILITY', False)
if test_ll_compatibility:
if args.test_ll_compatibility:
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
num_sms = 24
num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if test_ll_compatibility else 0)
num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if args.test_ll_compatibility else 0)
buffer = deep_ep.Buffer(group, int(1e9), int(1e9), low_latency_mode=test_ll_compatibility,
buffer = deep_ep.Buffer(group, int(1e9), int(1e9), low_latency_mode=args.test_ll_compatibility,
num_qps_per_rank=num_qps_per_rank)
assert num_local_ranks == 8 and num_ranks > 8
torch.manual_seed(rank)
for i in (num_sms, ):
test_main(i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group, args)
test_main(args, i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group)
if local_rank == 0:
print('', flush=True)
# Test compatibility with low latency functions
if test_ll_compatibility:
if args.test_ll_compatibility:
buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)
......@@ -255,8 +255,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args):
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='Test internode expert parallel')
parser = argparse.ArgumentParser(description='Test internode EP kernels')
parser.add_argument('--num-processes', type=int, default=8,
help='Number of processes to spawn (default: 8)')
parser.add_argument('--num-tokens', type=int, default=4096,
......@@ -264,21 +263,19 @@ if __name__ == '__main__':
parser.add_argument('--hidden', type=int, default=7168,
help='Hidden dimension size (default: 7168)')
parser.add_argument('--num-topk-groups', type=int, default=None,
help='Number of top-k groups (default: min(num_nodes, 4))')
help='Number of top-k groups (default: `min(num_nodes, 4)`)')
parser.add_argument('--num-topk', type=int, default=8,
help='Number of top-k experts (default: 8)')
parser.add_argument('--num-experts', type=int, default=None,
help='Number of experts (default: calculated as (256 // num_ranks) * num_ranks)')
parser.add_argument('--num-experts', type=int, default=256,
help='Number of experts (default: 256)')
parser.add_argument('--test-ll-compatibility', action='store_true',
help='whether to test compatibility with low-latency kernels')
args = parser.parse_args()
# Set default num_topk_groups if not provided
# Set default `num_topk_groups` if not provided
if args.num_topk_groups is None:
num_nodes = int(os.getenv('WORLD_SIZE', 1))
args.num_topk_groups = min(num_nodes, 4)
# Set default num_experts if not provided
if args.num_experts is None:
args.num_experts = (256 // args.num_processes) * args.num_processes
num_processes = args.num_processes
torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)
import os
import argparse
import time
import torch
import torch.distributed as dist
......@@ -11,12 +11,12 @@ from utils import init_dist, bench, calc_diff, inplace_unique, per_token_cast_to
import test_low_latency
def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup, args):
# noinspection PyShadowingNames
def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks: int, rank: int,
buffer: deep_ep.Buffer, group: dist.ProcessGroup):
# Settings
num_tokens = args.num_tokens
hidden = args.hidden
num_topk = args.num_topk
num_experts = args.num_experts
num_tokens, hidden = args.num_tokens, args.hidden
num_topk, num_experts = args.num_topk, args.num_experts
assert num_experts % num_ranks == 0
if local_rank == 0:
......@@ -229,8 +229,8 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
print('', flush=True)
# noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int, args):
# noinspection PyUnboundLocalVariable,PyShadowingNames
def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
test_ll_compatibility, num_rdma_bytes = False, 0
if test_ll_compatibility:
......@@ -242,7 +242,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args):
torch.manual_seed(rank)
for i in (24, ):
test_main(i, local_rank, num_ranks, rank, buffer, group, args)
test_main(args, i, local_rank, num_ranks, rank, buffer, group)
if local_rank == 0:
print('', flush=True)
......@@ -257,8 +257,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args):
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='Test intranode expert parallel')
parser = argparse.ArgumentParser(description='Test intranode EP kernels')
parser.add_argument('--num-processes', type=int, default=8,
help='Number of processes to spawn (default: 8)')
parser.add_argument('--num-tokens', type=int, default=4096,
......@@ -267,13 +266,9 @@ if __name__ == '__main__':
help='Hidden dimension size (default: 7168)')
parser.add_argument('--num-topk', type=int, default=8,
help='Number of top-k experts (default: 8)')
parser.add_argument('--num-experts', type=int, default=None,
help='Number of experts (default: calculated as (256 // num_ranks) * num_ranks)')
parser.add_argument('--num-experts', type=int, default=256,
help='Number of experts (default: 256)')
args = parser.parse_args()
# Set default num_experts if not provided
if args.num_experts is None:
args.num_experts = (256 // args.num_processes) * args.num_processes
num_processes = args.num_processes
torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)
import os
import argparse
import random
import torch
import torch.distributed as dist
......@@ -16,7 +16,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
assert num_experts % num_ranks == 0
num_local_experts = num_experts // num_ranks
# NOTES: the integers greater than 256 exceeds the BF16 precision limit
# NOTES: the integers greater than 256 exceed the BF16 precision limit
rank_offset = 128
assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)'
......@@ -98,16 +98,6 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
assert diff < (7e-4 if round_scale else 1e-5), f'Error: {diff=}, {zero_copy=}'
hash_value ^= hash_tensor(combined_x)
def create_test_cast_with_outliers(num_outliers):
tmp = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
tmp /= tmp.abs().amax(dim=1).view(-1, 1)
assert tmp.abs().amax().item() <= 1
# Create some amax outliers
for i in range(num_outliers):
tmp[random.randint(0, num_tokens - 1)] *= 1e3
return tmp
# noinspection PyShadowingNames
def large_gemm_with_hook(hook):
mat_0 = torch.randn((8192, 8192), dtype=torch.float)
......@@ -156,13 +146,11 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
return hash_value
# noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int, args):
# noinspection PyUnboundLocalVariable,PyShadowingNames
def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
num_tokens = args.num_tokens
hidden = args.hidden
num_topk = args.num_topk
num_experts = args.num_experts
num_tokens, hidden = args.num_tokens, args.hidden
num_topk, num_experts = args.num_topk, args.num_experts
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
if local_rank == 0:
......@@ -186,8 +174,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args):
if __name__ == '__main__':
# TODO: you may modify NUMA binding for less CPU overhead
import argparse
parser = argparse.ArgumentParser(description='Test low latency expert parallel')
parser = argparse.ArgumentParser(description='Test low-latency EP kernels')
parser.add_argument('--num-processes', type=int, default=8,
help='Number of processes to spawn (default: 8)')
parser.add_argument('--num-tokens', type=int, default=128,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment