Commit 20cc924b authored by Jiezhong Qiu's avatar Jiezhong Qiu
Browse files

update

parent 046455a8
...@@ -20,7 +20,7 @@ def perf(): ...@@ -20,7 +20,7 @@ def perf():
n_runs = 16 n_runs = 16
tott = 0. tott = 0.
for i in range(n_runs): for i in range(n_runs):
gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda() gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda("cuda:1")
ts = time.time() ts = time.time()
o = moe(inp, gate) o = moe(inp, gate)
te = time.time() te = time.time()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment