Commit b9c28810 authored by Rick Ho

multi-gpu performance test

parent 76327544
@@ -84,7 +84,7 @@ class MOEGlobal(Function):
 def moe(inp, gate, weight, world_size):
-    if world_size is not None:
+    if world_size is not None and world_size > 1:
         return MOEGlobal.apply(inp, gate, weight, world_size)
     else:
         return MOELocal.apply(inp, gate, weight)
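With this change, a world_size of 1 now takes the single-process MOELocal path instead of routing through the distributed MOEGlobal machinery. For reference, a minimal sketch of how the dispatch behaves (the per-expert weight layout and the import of moe are assumptions; only the world_size handling comes from the diff above):

# Sketch only, not part of the commit. Assumes moe() is importable from the
# repository module and that weights are laid out as (num_expert, in_feat, out_feat).
inp = torch.rand(16, 1024).cuda()
gate = torch.randint(0, 4, (16,)).int().cuda()
weight = torch.rand(4, 1024, 4096).cuda()

out = moe(inp, gate, weight, None)   # no world size given: MOELocal path
out = moe(inp, gate, weight, 1)      # world_size == 1 now also stays local
# moe(inp, gate, weight, 4)          # world_size > 1: MOEGlobal, requires torch.distributed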
@@ -5,24 +5,37 @@ import time
 import sys
-dev_name = 'cuda:1'
+dev_name_default = 'cuda:0'
 def perf():
     torch.manual_seed(42 + torch.distributed.get_rank())
     torch.cuda.manual_seed(42 + torch.distributed.get_rank())
-    batch_size = int(sys.argv[1])
-    in_feat = int(sys.argv[2])
-    out_feat = int(sys.argv[3])
-    num_expert = int(sys.argv[4])
+    if len(sys.argv) == 6:
+        batch_size = int(sys.argv[2])
+        in_feat = int(sys.argv[3])
+        out_feat = int(sys.argv[4])
+        num_expert = int(sys.argv[5])
+    else:
+        batch_size = 4096
+        in_feat = 1024
+        out_feat = 4096
+        num_expert = 4
+    if torch.distributed.get_rank() == 0:
+        print('Performance test case bs {} {}x{} ne {}'.format(batch_size,
+            in_feat, out_feat, num_expert))
+    if torch.distributed.get_world_size() > 1:
+        dev_name = 'cuda'
+    else:
+        dev_name = dev_name_default
     inp = torch.rand(batch_size, in_feat).cuda(dev_name)
     gate = torch.randint(low=0,
             high=num_expert * torch.distributed.get_world_size(),
             size=(batch_size, ), requires_grad=False).int().cuda(dev_name)
-    moe = MOELayer(num_expert, in_feat, out_feat).cuda(dev_name)
+    moe = MOELayer(num_expert, in_feat, out_feat, world_size).cuda(dev_name)
     moe.train()
     o = moe(inp, gate)
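Note that the bare 'cuda' device string used in the multi-GPU branch resolves to whichever device is current in each process, so the test assumes every MPI rank has already been bound to its own GPU. A minimal sketch of that binding, assuming one rank per GPU (the modulo mapping is an assumption and is not part of this commit):

# Sketch only: bind each MPI rank to one GPU so that the plain 'cuda'
# device string above selects a different card per process.
rank = torch.distributed.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())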
@@ -146,6 +159,14 @@ def test_dp():
 if __name__ == '__main__':
     torch.distributed.init_process_group(backend='mpi')
     world_size = torch.distributed.get_world_size()
-    test()
-    # print('{} / {}'.format(torch.distributed.get_rank(), torch.distributed.get_world_size()))
-    # perf()
+    if len(sys.argv) == 2:
+        task = sys.argv[1]
+        print('Specified task {}'.format(task))
+        if task == 'correctness':
+            test()
+        elif task == 'dp':
+            test_dp()
+        elif task == 'performance':
+            perf()
+    else:
+        test()
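Together with the argument parsing in perf() above, the expected invocation puts the task name at argv[1] and the shape parameters at argv[2..5], e.g. running the performance test on two GPUs with something like mpirun -np 2 python3 moe_test.py performance 4096 1024 4096 4. The exact MPI launcher and flags depend on the local setup, since the script initialises the 'mpi' process-group backend.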
@@ -9,22 +9,6 @@ export LD_LIBRARY_PATH=/home/laekov/.local/lib/python3.7/site-packages/torch/lib
 if [ -z $1 ]
 then
     python3 moe_test.py 2>logs/$OMPI_COMM_WORLD_RANK.log
-elif [ .$1 = '.test_all' ]
-then
-    for nexp in 1 2 4
-    do
-        for inf in 1024
-        do
-            for ouf in 4096
-            do
-                for bs in 4 16 64 256 512 1024 2048 4096
-                do
-                    echo $bs $nexp ${inf}x${ouf}
-                    python3 moe_test.py $bs $inf $ouf $nexp
-                done
-            done
-        done
-    done
 else
-    python3 $@ # 2>logs/$OMPI_COMM_WORLD_RANK.log
+    python3 $@ 2>logs/$OMPI_COMM_WORLD_RANK.log
 fi
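The built-in test_all sweep is dropped here; its python3 moe_test.py $bs $inf $ouf $nexp call used the old positional arguments, which no longer match the task-first argv layout introduced above. Sweeps are presumably meant to be driven externally now, by passing the full command through the script's pass-through branch (for example appending moe_test.py performance 4096 1024 4096 4), with each rank's stderr still collected under logs/$OMPI_COMM_WORLD_RANK.log.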