Commit c931c484 authored by Rich Ho's avatar Rich Ho
Browse files

fix zero test

parent 5680c599
import os
import sys import sys
import json
import torch import torch
from fmoe.layers import _fmoe_general_global_forward from fmoe.layers import _fmoe_general_global_forward
from fmoe import FMoETransformerMLP from fmoe import FMoETransformerMLP
...@@ -12,7 +14,7 @@ class ConstantGate(torch.nn.Module): ...@@ -12,7 +14,7 @@ class ConstantGate(torch.nn.Module):
self.top_k = top_k self.top_k = top_k
def forward(self, inp):
    """Constant gate: send every token to expert 0 with a flat score.

    Returns:
        idx: int64 tensor of shape (n_tokens, top_k), all zeros, on inp's
            device — every top-k slot routes to expert 0.
        score: float tensor of shape (n_tokens, 1, top_k) filled with 0.5.
    """
    n_tokens = inp.shape[0]
    idx = torch.zeros(n_tokens, self.top_k, dtype=torch.int64,
                      device=inp.device)
    score = torch.full((n_tokens, 1, self.top_k), 0.5, device=inp.device)
    return idx, score
...@@ -47,7 +49,7 @@ def test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1): ...@@ -47,7 +49,7 @@ def test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
script=__file__ script=__file__
) )
def test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1): def _test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
inp = torch.rand(batch_size, d_hidden).cuda() inp = torch.rand(batch_size, d_hidden).cuda()
model = FMoETransformerMLP(num_expert, d_hidden, d_hidden * 4, world_size, model = FMoETransformerMLP(num_expert, d_hidden, d_hidden * 4, world_size,
gate=ConstantGate).cuda() gate=ConstantGate).cuda()
...@@ -57,6 +59,9 @@ def test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1): ...@@ -57,6 +59,9 @@ def test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
if __name__ == '__main__':
    # Invoked as: python <script> <test_fn_name> '<json-kwargs>'.
    # Runs one named test function under torch.distributed with NCCL.
    if len(sys.argv) >= 3:
        args = json.loads(sys.argv[2])
        # Translate Open MPI launcher env vars into the names that
        # torch.distributed's env:// init expects; default to a
        # single-process setup when not launched via mpirun.
        rank = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
        os.environ["RANK"] = rank
        os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
        # Pin each rank to its own GPU so NCCL ranks don't collide.
        os.environ["CUDA_VISIBLE_DEVICES"] = rank
        torch.distributed.init_process_group(backend="nccl")
        # The actual world size wins over whatever the JSON args carried.
        args['world_size'] = torch.distributed.get_world_size()
        # At module top level, locals() is the module namespace: look up
        # the requested test function by name and call it with the kwargs.
        locals()[sys.argv[1]](**args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment