Commit 9f7cbd15 authored by zms1999's avatar zms1999
Browse files

fix type mismatch

parent 0edd811d
...@@ -35,7 +35,7 @@ class MoEForward(Function): ...@@ -35,7 +35,7 @@ class MoEForward(Function):
x.requires_grad = True x.requires_grad = True
# To skip torch autograd's version check. # To skip torch autograd's version check.
with torch.autograd.graph.saved_tensors_hooks(nothing, nothing): with torch.autograd.graph.saved_tensors_hooks(nothing, nothing):
y0 = expert_fn(x, [x.shape[0]]) y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64))
ctx.gibs[idx] = x ctx.gibs[idx] = x
ctx.gobs[idx] = y0 ctx.gobs[idx] = y0
y.copy_(y0) y.copy_(y0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment