Commit 9f7cbd15 authored by zms1999's avatar zms1999
Browse files

fix type mismatch

parent 0edd811d
......@@ -35,7 +35,7 @@ class MoEForward(Function):
x.requires_grad = True
# To skip torch autograd's version check.
with torch.autograd.graph.saved_tensors_hooks(nothing, nothing):
y0 = expert_fn(x, [x.shape[0]])
y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64))
ctx.gibs[idx] = x
ctx.gobs[idx] = y0
y.copy_(y0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment