"git@developer.sourcefind.cn:OpenDAS/torchani.git" did not exist on "70eab38f5a908aa40cf1ab6c15e9f0a3d03f12c8"
Commit db247e85 authored by Rick Ho's avatar Rick Ho
Browse files

fix schedule with older autograd

parent 5083a736
...@@ -33,8 +33,13 @@ class MoEForward(Function): ...@@ -33,8 +33,13 @@ class MoEForward(Function):
x = x.data x = x.data
with torch.enable_grad(): with torch.enable_grad():
x.requires_grad = True x.requires_grad = True
# To skip torch autograd's version check. try:
with torch.autograd.graph.saved_tensors_hooks(nothing, nothing): # To skip torch autograd's version check.
with torch.autograd.graph.saved_tensors_hooks(nothing, nothing):
y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64))
except Exception as e:
# Ignore the error and fall back for compatibility to older
# versions of PyTorch
y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64)) y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64))
ctx.gibs[idx] = x ctx.gibs[idx] = x
ctx.gobs[idx] = y0 ctx.gobs[idx] = y0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment