Unverified Commit 28b09047 authored by Chang Liu's avatar Chang Liu Committed by GitHub
Browse files

[Example][Bugfix] Fix dimenet example (#4219)


Co-authored-by: default avatarXin Yao <xiny@nvidia.com>
parent 9ee7ced5
...@@ -17,7 +17,9 @@ class BesselBasisLayer(nn.Module): ...@@ -17,7 +17,9 @@ class BesselBasisLayer(nn.Module):
self.reset_params() self.reset_params()
def reset_params(self): def reset_params(self):
torch.arange(1, self.frequencies.numel() + 1, out=self.frequencies).mul_(np.pi) with torch.no_grad():
torch.arange(1, self.frequencies.numel() + 1, out=self.frequencies).mul_(np.pi)
self.frequencies.requires_grad_()
def forward(self, g): def forward(self, g):
d_scaled = g.edata['d'] / self.cutoff d_scaled = g.edata['d'] / self.cutoff
...@@ -25,4 +27,4 @@ class BesselBasisLayer(nn.Module): ...@@ -25,4 +27,4 @@ class BesselBasisLayer(nn.Module):
d_scaled = torch.unsqueeze(d_scaled, -1) d_scaled = torch.unsqueeze(d_scaled, -1)
d_cutoff = self.envelope(d_scaled) d_cutoff = self.envelope(d_scaled)
g.edata['rbf'] = d_cutoff * torch.sin(self.frequencies * d_scaled) g.edata['rbf'] = d_cutoff * torch.sin(self.frequencies * d_scaled)
return g return g
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment