Commit 01f49071 authored by Chenggang Zhao
Browse files

Unify testing envs' naming

parent 8dcdd349
...@@ -13,11 +13,11 @@ import test_low_latency ...@@ -13,11 +13,11 @@ import test_low_latency
def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup): def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup):
# Settings # Settings
num_tokens = int(os.environ.get("DEEPEP_TEST_NUM_TOKENS", "4096")) num_tokens = int(os.environ.get('EP_TEST_NUM_TOKENS', '4096'))
hidden = int(os.environ.get("DEEPEP_TEST_HIDDEN", "7168")) hidden = int(os.environ.get('EP_TEST_HIDDEN', '7168'))
num_topk_groups = int(os.environ.get("DEEPEP_TEST_NUM_TOPK_GROUPS", str(min(num_nodes, 4)))) num_topk_groups = int(os.environ.get('EP_TEST_NUM_TOPK_GROUPS', str(min(num_nodes, 4))))
num_topk = int(os.environ.get("DEEPEP_TEST_NUM_TOPK", "8")) num_topk = int(os.environ.get('EP_TEST_NUM_TOPK', '8'))
num_experts = int(os.environ.get("DEEPEP_TEST_NUM_EXPERTS", str((256 // num_ranks) * num_ranks))) num_experts = int(os.environ.get('EP_TEST_NUM_EXPERTS', str((256 // num_ranks) * num_ranks)))
assert num_experts % num_ranks == 0 and num_local_ranks == 8 assert num_experts % num_ranks == 0 and num_local_ranks == 8
if local_rank == 0: if local_rank == 0:
...@@ -255,5 +255,5 @@ def test_loop(local_rank: int, num_local_ranks: int): ...@@ -255,5 +255,5 @@ def test_loop(local_rank: int, num_local_ranks: int):
if __name__ == '__main__': if __name__ == '__main__':
num_processes = int(os.getenv("DEEPEP_TEST_NUM_PROCESSES", "8")) num_processes = int(os.getenv('EP_TEST_NUM_PROCESSES', '8'))
torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes) torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes)
...@@ -13,10 +13,10 @@ import test_low_latency ...@@ -13,10 +13,10 @@ import test_low_latency
def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup): def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup):
# Settings # Settings
num_tokens = int(os.environ.get("DEEPEP_TEST_NUM_TOKENS", "4096")) num_tokens = int(os.environ.get('EP_TEST_NUM_TOKENS', '4096'))
hidden = int(os.environ.get("DEEPEP_TEST_HIDDEN", "7168")) hidden = int(os.environ.get('EP_TEST_HIDDEN', '7168'))
num_topk = int(os.environ.get("DEEPEP_TEST_NUM_TOPK", "8")) num_topk = int(os.environ.get('EP_TEST_NUM_TOPK', '8'))
num_experts = int(os.environ.get("DEEPEP_TEST_NUM_EXPERTS", str((256 // num_ranks) * num_ranks))) num_experts = int(os.environ.get('EP_TEST_NUM_EXPERTS', str((256 // num_ranks) * num_ranks)))
assert num_experts % num_ranks == 0 assert num_experts % num_ranks == 0
if local_rank == 0: if local_rank == 0:
...@@ -257,5 +257,5 @@ def test_loop(local_rank: int, num_local_ranks: int): ...@@ -257,5 +257,5 @@ def test_loop(local_rank: int, num_local_ranks: int):
if __name__ == '__main__': if __name__ == '__main__':
num_processes = int(os.getenv("DEEPEP_TEST_NUM_PROCESSES", "8")) num_processes = int(os.getenv('EP_TEST_NUM_PROCESSES', '8'))
torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes) torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes)
...@@ -160,10 +160,10 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -160,10 +160,10 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
# noinspection PyUnboundLocalVariable # noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int): def test_loop(local_rank: int, num_local_ranks: int):
rank, num_ranks, group = init_dist(local_rank, num_local_ranks) rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
num_tokens = int(os.environ.get("DEEPEP_TEST_NUM_TOKENS", "128")) num_tokens = int(os.environ.get('EP_TEST_NUM_TOKENS', '128'))
hidden = int(os.environ.get("DEEPEP_TEST_HIDDEN", "7168")) hidden = int(os.environ.get('EP_TEST_HIDDEN', '7168'))
num_topk = int(os.environ.get("DEEPEP_TEST_NUM_TOPK", "8")) num_topk = int(os.environ.get('EP_TEST_NUM_TOPK', '8'))
num_experts = int(os.environ.get("DEEPEP_TEST_NUM_EXPERTS", "288")) num_experts = int(os.environ.get('EP_TEST_NUM_EXPERTS', '288'))
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts) num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
if local_rank == 0: if local_rank == 0:
...@@ -187,5 +187,5 @@ def test_loop(local_rank: int, num_local_ranks: int): ...@@ -187,5 +187,5 @@ def test_loop(local_rank: int, num_local_ranks: int):
if __name__ == '__main__': if __name__ == '__main__':
# TODO: you may modify NUMA binding for less CPU overhead # TODO: you may modify NUMA binding for less CPU overhead
num_processes = int(os.getenv("DEEPEP_TEST_NUM_PROCESSES", "8")) num_processes = int(os.getenv('EP_TEST_NUM_PROCESSES', '8'))
torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes) torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes)
...@@ -14,9 +14,6 @@ def init_dist(local_rank: int, num_local_ranks: int): ...@@ -14,9 +14,6 @@ def init_dist(local_rank: int, num_local_ranks: int):
num_nodes = int(os.getenv('WORLD_SIZE', 1)) num_nodes = int(os.getenv('WORLD_SIZE', 1))
node_rank = int(os.getenv('RANK', 0)) node_rank = int(os.getenv('RANK', 0))
num_processes = int(os.getenv("DEEPEP_TEST_NUM_PROCESSES", "8"))
assert (num_local_ranks < num_processes and num_nodes == 1) or num_local_ranks == num_processes
sig = inspect.signature(dist.init_process_group) sig = inspect.signature(dist.init_process_group)
params = { params = {
'backend': 'nccl', 'backend': 'nccl',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment