Unverified Commit 19fc0700 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

more (#238)

parent 0a47402f
......@@ -250,5 +250,5 @@ def test_loop(local_rank: int, num_local_ranks: int):
if __name__ == '__main__':
num_processes = 8
num_processes = int(os.getenv("DEEPEP_TEST_NUM_PROCESSES", "8"))
torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes)
import os
import time
import torch
import torch.distributed as dist
......@@ -252,5 +253,5 @@ def test_loop(local_rank: int, num_local_ranks: int):
if __name__ == '__main__':
num_processes = 8
num_processes = int(os.getenv("DEEPEP_TEST_NUM_PROCESSES", "8"))
torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes)
import os
import random
import torch
import torch.distributed as dist
......@@ -183,5 +184,5 @@ def test_loop(local_rank: int, num_local_ranks: int):
if __name__ == '__main__':
# TODO: you may modify NUMA binding for less CPU overhead
num_processes = 8
num_processes = int(os.getenv("DEEPEP_TEST_NUM_PROCESSES", "8"))
torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes)
......@@ -13,7 +13,9 @@ def init_dist(local_rank: int, num_local_ranks: int):
port = int(os.getenv('MASTER_PORT', '8361'))
num_nodes = int(os.getenv('WORLD_SIZE', 1))
node_rank = int(os.getenv('RANK', 0))
assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8
num_processes = int(os.getenv("DEEPEP_TEST_NUM_PROCESSES", "8"))
assert (num_local_ranks < num_processes and num_nodes == 1) or num_local_ranks == num_processes
sig = inspect.signature(dist.init_process_group)
params = {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment