Commit 254ad118 authored by Rick Ho's avatar Rick Ho
Browse files

test support for fused fw and bw

parent 35addec6
......@@ -6,14 +6,14 @@ import sys
def perf():
batch_size = int(sys.argv[1])
in_feat = int(sys.argv[2])
out_feat = int(sys.argv[3])
io_feat = int(sys.argv[2])
hidden_feat = int(sys.argv[3])
num_expert = int(sys.argv[4])
inp = torch.rand(batch_size, in_feat).cuda()
inp = torch.rand(batch_size, io_feat).cuda()
gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda()
moe = MOELayer(num_expert, in_feat, out_feat).cuda()
moe = MOELayer(num_expert, io_feat, hidden_feat, io_feat).cuda()
o = moe(inp, gate)
o = moe(inp, gate)
......@@ -35,7 +35,7 @@ def perf():
sqtot += (te - ts)**2
maxt = max(maxt, te - ts)
gflops = 2e-9 * n_runs * in_feat * out_feat * batch_size / tott
gflops = 2e-9 * n_runs * io_feat * hidden_feat * 2 * batch_size / tott
print('Time mean/max/stdev {:.3f} {:.3f} {:.3f} ms, {:.3f} GFLOPs'.format(
tott * 1e3 / n_runs, maxt * 1e3,
(sqtot / n_runs - (tott / n_runs)**2) * 1e3 / n_runs, gflops))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment