Commit 184b0404 authored by Sengxian's avatar Sengxian
Browse files

Fix bug in DDP test

parent 5ead59db
......@@ -157,12 +157,12 @@ def test_fmoe(
)
para_array = [torch.empty_like(para_tensor) for _ in range(world_size)]
torch.distributed.all_gather(para_array, para_tensor)
para_tesnor_gathered = torch.cat(para_array, dim=0)
assert len(para_array) == len(moe_raw.experts)
for expertID in range(para_tesnor_gathered.shape[0]):
list(moe_raw.experts[expertID].parameters())[idx].data = para_tensor[
expertID
]
para_tensor_gathered = torch.cat(para_array, dim=0)
assert para_tensor_gathered.shape[0] == len(moe_raw.experts)
for expertID in range(para_tensor_gathered.shape[0]):
list(moe_raw.experts[expertID].parameters())[
idx
].data = para_tensor_gathered[expertID]
moe_out, raw_out = _perform_forward(moe, moe_raw, batch_size, d_model, top_k)
......@@ -202,10 +202,10 @@ def _run_distributed(func: Callable, args: Dict):
ps, n = [], 2
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "36666"
os.environ["WORLD_SIZE"] = str(n)
os.environ["OMPI_COMM_WORLD_SIZE"] = str(n)
for i in range(n):
os.environ["RANK"] = str(i)
os.environ["OMPI_COMM_WORLD_RANK"] = str(i)
os.environ["CUDA_VISIBLE_DEVICES"] = str(i)
p = subprocess.Popen(
[sys.executable, __file__, func.__name__, json.dumps(args)],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment