Unverified Commit 28b09047 authored by Chang Liu's avatar Chang Liu Committed by GitHub
Browse files

[Example][Bugfix] Fix dimenet example (#4219)


Co-authored-by: default avatarXin Yao <xiny@nvidia.com>
parent 9ee7ced5
......@@ -17,7 +17,9 @@ class BesselBasisLayer(nn.Module):
self.reset_params()
def reset_params(self):
torch.arange(1, self.frequencies.numel() + 1, out=self.frequencies).mul_(np.pi)
with torch.no_grad():
torch.arange(1, self.frequencies.numel() + 1, out=self.frequencies).mul_(np.pi)
self.frequencies.requires_grad_()
def forward(self, g):
d_scaled = g.edata['d'] / self.cutoff
......@@ -25,4 +27,4 @@ class BesselBasisLayer(nn.Module):
d_scaled = torch.unsqueeze(d_scaled, -1)
d_cutoff = self.envelope(d_scaled)
g.edata['rbf'] = d_cutoff * torch.sin(self.frequencies * d_scaled)
return g
\ No newline at end of file
return g
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment