Commit c931c484 authored by Rich Ho's avatar Rich Ho
Browse files

fix zero test

parent 5680c599
import os
import sys
import json
import torch
from fmoe.layers import _fmoe_general_global_forward
from fmoe import FMoETransformerMLP
......@@ -12,7 +14,7 @@ class ConstantGate(torch.nn.Module):
self.top_k = top_k
def forward(self, inp):
idx = torch.zeros((inp.shape[0] * self.top_k,), dtype=torch.int64,
idx = torch.zeros((inp.shape[0], self.top_k), dtype=torch.int64,
device=inp.device)
score = torch.ones((inp.shape[0], 1, self.top_k), device=inp.device) / 2
return idx, score
......@@ -47,7 +49,7 @@ def test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
script=__file__
)
def test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
def _test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
inp = torch.rand(batch_size, d_hidden).cuda()
model = FMoETransformerMLP(num_expert, d_hidden, d_hidden * 4, world_size,
gate=ConstantGate).cuda()
......@@ -57,6 +59,9 @@ def test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
if __name__ == '__main__':
if len(sys.argv) >= 3:
args = json.loads(sys.argv[2])
os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
torch.distributed.init_process_group(backend="nccl")
args['world_size'] = torch.distributed.get_world_size()
locals()[sys.argv[1]](**args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment