"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "fc0309615e42c32989e060e733d871e16617874e"
Commit 191c1e46 authored by Rick Ho's avatar Rick Ho
Browse files

better timing

parent 365b6f01
...@@ -15,19 +15,30 @@ def perf(): ...@@ -15,19 +15,30 @@ def perf():
moe = MOELayer(num_expert, in_feat, out_feat).cuda() moe = MOELayer(num_expert, in_feat, out_feat).cuda()
o = moe(inp, gate)
o = moe(inp, gate)
o = moe(inp, gate)
o = moe(inp, gate)
o = moe(inp, gate)
o = moe(inp, gate) o = moe(inp, gate)
n_runs = 16 n_runs = 16
tott = 0. tott = 0.
maxt = 0.
sqtot = 0.
for i in range(n_runs): for i in range(n_runs):
gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda() gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda()
ts = time.time() ts = time.time()
o = moe(inp, gate) o = moe(inp, gate)
te = time.time() te = time.time()
tott += te - ts tott += te - ts
sqtot += (te - ts)**2
maxt = max(maxt, te - ts)
gflops = 2e-9 * n_runs * in_feat * out_feat * batch_size / tott gflops = 2e-9 * n_runs * in_feat * out_feat * batch_size / tott
print('Mean time {:.3f} ms, {:.3f} GFLOPs'.format(tott * 1e3 / n_runs, gflops)) print('Time mean/max/stdev {:.3f} {:.3f} {:.3f} ms, {:.3f} GFLOPs'.format(
tott * 1e3 / n_runs, maxt * 1e3,
(sqtot / n_runs - (tott / n_runs)**2) * 1e3 / n_runs, gflops))
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment