Commit 6b17f4fa (unverified), authored by youkaichao, committed by GitHub
Browse files

Use CLI args instead of envs (#273)



* use cli arg for num_processes
Signed-off-by: youkaichao <youkaichao@gmail.com>

* update low-latency
Signed-off-by: youkaichao <youkaichao@gmail.com>

* update intranode
Signed-off-by: youkaichao <youkaichao@gmail.com>

* update internode
Signed-off-by: youkaichao <youkaichao@gmail.com>

---------
Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 341bb961
...@@ -11,13 +11,13 @@ from utils import init_dist, bench, calc_diff, create_grouped_scores, inplace_un ...@@ -11,13 +11,13 @@ from utils import init_dist, bench, calc_diff, create_grouped_scores, inplace_un
import test_low_latency import test_low_latency
def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup): def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup, args):
# Settings # Settings
num_tokens = int(os.environ.get('EP_TEST_NUM_TOKENS', '4096')) num_tokens = args.num_tokens
hidden = int(os.environ.get('EP_TEST_HIDDEN', '7168')) hidden = args.hidden
num_topk_groups = int(os.environ.get('EP_TEST_NUM_TOPK_GROUPS', str(min(num_nodes, 4)))) num_topk_groups = args.num_topk_groups
num_topk = int(os.environ.get('EP_TEST_NUM_TOPK', '8')) num_topk = args.num_topk
num_experts = int(os.environ.get('EP_TEST_NUM_EXPERTS', str((256 // num_ranks) * num_ranks))) num_experts = args.num_experts
assert num_experts % num_ranks == 0 and num_local_ranks == 8 assert num_experts % num_ranks == 0 and num_local_ranks == 8
if local_rank == 0: if local_rank == 0:
...@@ -224,7 +224,7 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in ...@@ -224,7 +224,7 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
# noinspection PyUnboundLocalVariable # noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int): def test_loop(local_rank: int, num_local_ranks: int, args):
num_nodes = int(os.getenv('WORLD_SIZE', 1)) num_nodes = int(os.getenv('WORLD_SIZE', 1))
rank, num_ranks, group = init_dist(local_rank, num_local_ranks) rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
test_ll_compatibility = os.getenv('EP_TEST_LL_COMPATIBILITY', False) test_ll_compatibility = os.getenv('EP_TEST_LL_COMPATIBILITY', False)
...@@ -240,7 +240,7 @@ def test_loop(local_rank: int, num_local_ranks: int): ...@@ -240,7 +240,7 @@ def test_loop(local_rank: int, num_local_ranks: int):
torch.manual_seed(rank) torch.manual_seed(rank)
for i in (num_sms, ): for i in (num_sms, ):
test_main(i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group) test_main(i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group, args)
if local_rank == 0: if local_rank == 0:
print('', flush=True) print('', flush=True)
...@@ -255,5 +255,30 @@ def test_loop(local_rank: int, num_local_ranks: int): ...@@ -255,5 +255,30 @@ def test_loop(local_rank: int, num_local_ranks: int):
if __name__ == '__main__': if __name__ == '__main__':
num_processes = int(os.getenv('EP_TEST_NUM_PROCESSES', '8')) import argparse
torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes) parser = argparse.ArgumentParser(description='Test internode expert parallel')
parser.add_argument('--num-processes', type=int, default=8,
help='Number of processes to spawn (default: 8)')
parser.add_argument('--num-tokens', type=int, default=4096,
help='Number of tokens (default: 4096)')
parser.add_argument('--hidden', type=int, default=7168,
help='Hidden dimension size (default: 7168)')
parser.add_argument('--num-topk-groups', type=int, default=None,
help='Number of top-k groups (default: min(num_nodes, 4))')
parser.add_argument('--num-topk', type=int, default=8,
help='Number of top-k experts (default: 8)')
parser.add_argument('--num-experts', type=int, default=None,
help='Number of experts (default: calculated as (256 // num_ranks) * num_ranks)')
args = parser.parse_args()
# Set default num_topk_groups if not provided
if args.num_topk_groups is None:
num_nodes = int(os.getenv('WORLD_SIZE', 1))
args.num_topk_groups = min(num_nodes, 4)
# Set default num_experts if not provided
if args.num_experts is None:
args.num_experts = (256 // args.num_processes) * args.num_processes
num_processes = args.num_processes
torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)
...@@ -11,12 +11,12 @@ from utils import init_dist, bench, calc_diff, inplace_unique, per_token_cast_to ...@@ -11,12 +11,12 @@ from utils import init_dist, bench, calc_diff, inplace_unique, per_token_cast_to
import test_low_latency import test_low_latency
def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup): def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup, args):
# Settings # Settings
num_tokens = int(os.environ.get('EP_TEST_NUM_TOKENS', '4096')) num_tokens = args.num_tokens
hidden = int(os.environ.get('EP_TEST_HIDDEN', '7168')) hidden = args.hidden
num_topk = int(os.environ.get('EP_TEST_NUM_TOPK', '8')) num_topk = args.num_topk
num_experts = int(os.environ.get('EP_TEST_NUM_EXPERTS', str((256 // num_ranks) * num_ranks))) num_experts = args.num_experts
assert num_experts % num_ranks == 0 assert num_experts % num_ranks == 0
if local_rank == 0: if local_rank == 0:
...@@ -230,7 +230,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: ...@@ -230,7 +230,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
# noinspection PyUnboundLocalVariable # noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int): def test_loop(local_rank: int, num_local_ranks: int, args):
rank, num_ranks, group = init_dist(local_rank, num_local_ranks) rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
test_ll_compatibility, num_rdma_bytes = False, 0 test_ll_compatibility, num_rdma_bytes = False, 0
if test_ll_compatibility: if test_ll_compatibility:
...@@ -242,7 +242,7 @@ def test_loop(local_rank: int, num_local_ranks: int): ...@@ -242,7 +242,7 @@ def test_loop(local_rank: int, num_local_ranks: int):
torch.manual_seed(rank) torch.manual_seed(rank)
for i in (24, ): for i in (24, ):
test_main(i, local_rank, num_ranks, rank, buffer, group) test_main(i, local_rank, num_ranks, rank, buffer, group, args)
if local_rank == 0: if local_rank == 0:
print('', flush=True) print('', flush=True)
...@@ -257,5 +257,23 @@ def test_loop(local_rank: int, num_local_ranks: int): ...@@ -257,5 +257,23 @@ def test_loop(local_rank: int, num_local_ranks: int):
if __name__ == '__main__': if __name__ == '__main__':
num_processes = int(os.getenv('EP_TEST_NUM_PROCESSES', '8')) import argparse
torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes) parser = argparse.ArgumentParser(description='Test intranode expert parallel')
parser.add_argument('--num-processes', type=int, default=8,
help='Number of processes to spawn (default: 8)')
parser.add_argument('--num-tokens', type=int, default=4096,
help='Number of tokens (default: 4096)')
parser.add_argument('--hidden', type=int, default=7168,
help='Hidden dimension size (default: 7168)')
parser.add_argument('--num-topk', type=int, default=8,
help='Number of top-k experts (default: 8)')
parser.add_argument('--num-experts', type=int, default=None,
help='Number of experts (default: calculated as (256 // num_ranks) * num_ranks)')
args = parser.parse_args()
# Set default num_experts if not provided
if args.num_experts is None:
args.num_experts = (256 // args.num_processes) * args.num_processes
num_processes = args.num_processes
torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)
...@@ -157,12 +157,12 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -157,12 +157,12 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
# noinspection PyUnboundLocalVariable # noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int): def test_loop(local_rank: int, num_local_ranks: int, args):
rank, num_ranks, group = init_dist(local_rank, num_local_ranks) rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
num_tokens = int(os.environ.get('EP_TEST_NUM_TOKENS', '128')) num_tokens = args.num_tokens
hidden = int(os.environ.get('EP_TEST_HIDDEN', '7168')) hidden = args.hidden
num_topk = int(os.environ.get('EP_TEST_NUM_TOPK', '8')) num_topk = args.num_topk
num_experts = int(os.environ.get('EP_TEST_NUM_EXPERTS', '288')) num_experts = args.num_experts
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts) num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
if local_rank == 0: if local_rank == 0:
...@@ -186,5 +186,19 @@ def test_loop(local_rank: int, num_local_ranks: int): ...@@ -186,5 +186,19 @@ def test_loop(local_rank: int, num_local_ranks: int):
if __name__ == '__main__': if __name__ == '__main__':
# TODO: you may modify NUMA binding for less CPU overhead # TODO: you may modify NUMA binding for less CPU overhead
num_processes = int(os.getenv('EP_TEST_NUM_PROCESSES', '8')) import argparse
torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes) parser = argparse.ArgumentParser(description='Test low latency expert parallel')
parser.add_argument('--num-processes', type=int, default=8,
help='Number of processes to spawn (default: 8)')
parser.add_argument('--num-tokens', type=int, default=128,
help='Number of tokens (default: 128)')
parser.add_argument('--hidden', type=int, default=7168,
help='Hidden dimension size (default: 7168)')
parser.add_argument('--num-topk', type=int, default=8,
help='Number of top-k experts (default: 8)')
parser.add_argument('--num-experts', type=int, default=288,
help='Number of experts (default: 288)')
args = parser.parse_args()
num_processes = args.num_processes
torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment