Commit 184b0404 authored by Sengxian

Fix bug in DDP test

parent 5ead59db
@@ -157,12 +157,12 @@ def test_fmoe(
         )
         para_array = [torch.empty_like(para_tensor) for _ in range(world_size)]
         torch.distributed.all_gather(para_array, para_tensor)
-        para_tesnor_gathered = torch.cat(para_array, dim=0)
-        assert len(para_array) == len(moe_raw.experts)
-        for expertID in range(para_tesnor_gathered.shape[0]):
-            list(moe_raw.experts[expertID].parameters())[idx].data = para_tensor[
-                expertID
-            ]
+        para_tensor_gathered = torch.cat(para_array, dim=0)
+        assert para_tensor_gathered.shape[0] == len(moe_raw.experts)
+        for expertID in range(para_tensor_gathered.shape[0]):
+            list(moe_raw.experts[expertID].parameters())[
+                idx
+            ].data = para_tensor_gathered[expertID]
         moe_out, raw_out = _perform_forward(moe, moe_raw, batch_size, d_model, top_k)
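
Note on the first hunk: the old code gathered every rank's expert parameters into para_array but then copied from the local para_tensor (and compared len(para_array), which is the world size, against the expert count). The fix indexes the concatenated para_tensor_gathered instead, so each raw expert receives the weights owned by the corresponding rank. Below is a minimal sketch of the gather pattern, assuming the process group is already initialized and each rank holds its local expert weights in one tensor; gather_expert_weights and local_weight are illustrative names, not part of the test.

# A minimal sketch of the all_gather pattern the fixed test relies on.
# Assumes torch.distributed is already initialized; names are hypothetical.
import torch
import torch.distributed as dist

def gather_expert_weights(local_weight: torch.Tensor) -> torch.Tensor:
    world_size = dist.get_world_size()
    # all_gather fills buckets[i] with rank i's tensor, on every rank.
    buckets = [torch.empty_like(local_weight) for _ in range(world_size)]
    dist.all_gather(buckets, local_weight)
    # Concatenating along dim 0 makes index e address global expert e;
    # the bug fixed above was indexing the local tensor (para_tensor)
    # instead of this gathered result.
    return torch.cat(buckets, dim=0)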
@@ -202,10 +202,10 @@ def _run_distributed(func: Callable, args: Dict):
     ps, n = [], 2
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "36666"
-    os.environ["WORLD_SIZE"] = str(n)
+    os.environ["OMPI_COMM_WORLD_SIZE"] = str(n)
     for i in range(n):
-        os.environ["RANK"] = str(i)
+        os.environ["OMPI_COMM_WORLD_RANK"] = str(i)
         os.environ["CUDA_VISIBLE_DEVICES"] = str(i)
         p = subprocess.Popen(
             [sys.executable, __file__, func.__name__, json.dumps(args)],
...
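
Note on the second hunk: the launcher now exports the MPI-style OMPI_COMM_WORLD_SIZE / OMPI_COMM_WORLD_RANK variables, presumably because the spawned workers' distributed initialization reads those rather than WORLD_SIZE / RANK. A minimal sketch of the same one-process-per-rank launch pattern follows, with hypothetical names (launch_workers, worker.py) and the per-rank variables passed via env= rather than by mutating os.environ.

# A minimal sketch, assuming each worker initializes torch.distributed from
# these variables and then calls the entry function named on its command line.
import json
import os
import subprocess
import sys

def launch_workers(n: int, entry: str, args: dict) -> None:
    procs = []
    for rank in range(n):
        env = dict(
            os.environ,
            MASTER_ADDR="localhost",
            MASTER_PORT="36666",
            OMPI_COMM_WORLD_SIZE=str(n),
            OMPI_COMM_WORLD_RANK=str(rank),
            CUDA_VISIBLE_DEVICES=str(rank),  # pin one GPU per rank
        )
        procs.append(
            subprocess.Popen(
                [sys.executable, "worker.py", entry, json.dumps(args)],
                env=env,
            )
        )
    for p in procs:
        assert p.wait() == 0  # surface any rank's failure in the test

Passing env= keeps each child's rank-specific settings isolated. The original's in-place mutation of os.environ also works, since Popen snapshots the environment at spawn time, but it would be fragile if launches ever overlapped.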