Unverified commit 19fc0700, authored by fzyzcjy, committed by GitHub
Browse files

more (#238)

parent 0a47402f
...@@ -250,5 +250,5 @@ def test_loop(local_rank: int, num_local_ranks: int): ...@@ -250,5 +250,5 @@ def test_loop(local_rank: int, num_local_ranks: int):
if __name__ == '__main__': if __name__ == '__main__':
num_processes = 8 num_processes = int(os.getenv("DEEPEP_TEST_NUM_PROCESSES", "8"))
torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes) torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes)
import os
import time import time
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -252,5 +253,5 @@ def test_loop(local_rank: int, num_local_ranks: int): ...@@ -252,5 +253,5 @@ def test_loop(local_rank: int, num_local_ranks: int):
if __name__ == '__main__': if __name__ == '__main__':
num_processes = 8 num_processes = int(os.getenv("DEEPEP_TEST_NUM_PROCESSES", "8"))
torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes) torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes)
import os
import random import random
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -183,5 +184,5 @@ def test_loop(local_rank: int, num_local_ranks: int): ...@@ -183,5 +184,5 @@ def test_loop(local_rank: int, num_local_ranks: int):
if __name__ == '__main__': if __name__ == '__main__':
# TODO: you may modify NUMA binding for less CPU overhead # TODO: you may modify NUMA binding for less CPU overhead
num_processes = 8 num_processes = int(os.getenv("DEEPEP_TEST_NUM_PROCESSES", "8"))
torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes) torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes)
...@@ -13,7 +13,9 @@ def init_dist(local_rank: int, num_local_ranks: int): ...@@ -13,7 +13,9 @@ def init_dist(local_rank: int, num_local_ranks: int):
port = int(os.getenv('MASTER_PORT', '8361')) port = int(os.getenv('MASTER_PORT', '8361'))
num_nodes = int(os.getenv('WORLD_SIZE', 1)) num_nodes = int(os.getenv('WORLD_SIZE', 1))
node_rank = int(os.getenv('RANK', 0)) node_rank = int(os.getenv('RANK', 0))
assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8
num_processes = int(os.getenv("DEEPEP_TEST_NUM_PROCESSES", "8"))
assert (num_local_ranks < num_processes and num_nodes == 1) or num_local_ranks == num_processes
sig = inspect.signature(dist.init_process_group) sig = inspect.signature(dist.init_process_group)
params = { params = {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment