Unverified Commit 2e5032a4 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Small modifications in code and comments to clarify (#11) (#539)


Co-authored-by: default avatarIgnacio Pickering <ign.pickering@gmail.com>
parent 25bd59fb
...@@ -36,17 +36,17 @@ def radial_terms(Rcr: float, EtaR: Tensor, ShfR: Tensor, distances: Tensor) -> T ...@@ -36,17 +36,17 @@ def radial_terms(Rcr: float, EtaR: Tensor, ShfR: Tensor, distances: Tensor) -> T
.. _ANI paper: .. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
""" """
distances = distances.unsqueeze(-1).unsqueeze(-1) distances = distances.view(-1, 1, 1)
fc = cutoff_cosine(distances, Rcr) fc = cutoff_cosine(distances, Rcr)
# Note that in the equation in the paper there is no 0.25 # Note that in the equation in the paper there is no 0.25
# coefficient, but in NeuroChem there is such a coefficient. # coefficient, but in NeuroChem there is such a coefficient.
# We choose to be consistent with NeuroChem instead of the paper here. # We choose to be consistent with NeuroChem instead of the paper here.
ret = 0.25 * torch.exp(-EtaR * (distances - ShfR)**2) * fc ret = 0.25 * torch.exp(-EtaR * (distances - ShfR)**2) * fc
# At this point, ret now have shape # At this point, ret now has shape
# (conformations, atoms, N, ?, ?) where ? depend on constants. # (conformations x atoms, ?, ?) where ? depend on constants.
# We then should flat the last 2 dimensions to view the subAEV as one # We then should flat the last 2 dimensions to view the subAEV as a two
# dimension vector # dimensional tensor (onnx doesn't support negative indices in flatten)
return ret.flatten(start_dim=-2) return ret.flatten(start_dim=1)
def angular_terms(Rca: float, ShfZ: Tensor, EtaA: Tensor, Zeta: Tensor, def angular_terms(Rca: float, ShfZ: Tensor, EtaA: Tensor, Zeta: Tensor,
...@@ -63,7 +63,7 @@ def angular_terms(Rca: float, ShfZ: Tensor, EtaA: Tensor, Zeta: Tensor, ...@@ -63,7 +63,7 @@ def angular_terms(Rca: float, ShfZ: Tensor, EtaA: Tensor, Zeta: Tensor,
.. _ANI paper: .. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
""" """
vectors12 = vectors12.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) vectors12 = vectors12.view(2, -1, 3, 1, 1, 1, 1)
distances12 = vectors12.norm(2, dim=-5) distances12 = vectors12.norm(2, dim=-5)
cos_angles = vectors12.prod(0).sum(1) / distances12.prod(0) cos_angles = vectors12.prod(0).sum(1) / distances12.prod(0)
...@@ -74,11 +74,11 @@ def angular_terms(Rca: float, ShfZ: Tensor, EtaA: Tensor, Zeta: Tensor, ...@@ -74,11 +74,11 @@ def angular_terms(Rca: float, ShfZ: Tensor, EtaA: Tensor, Zeta: Tensor,
factor1 = ((1 + torch.cos(angles - ShfZ)) / 2) ** Zeta factor1 = ((1 + torch.cos(angles - ShfZ)) / 2) ** Zeta
factor2 = torch.exp(-EtaA * (distances12.sum(0) / 2 - ShfA) ** 2) factor2 = torch.exp(-EtaA * (distances12.sum(0) / 2 - ShfA) ** 2)
ret = 2 * factor1 * factor2 * fcj12.prod(0) ret = 2 * factor1 * factor2 * fcj12.prod(0)
# At this point, ret now have shape # At this point, ret now has shape
# (conformations, atoms, N, ?, ?, ?, ?) where ? depend on constants. # (conformations x atoms, ?, ?, ?, ?) where ? depend on constants.
# We then should flat the last 4 dimensions to view the subAEV as one # We then should flat the last 4 dimensions to view the subAEV as a two
# dimension vector # dimensional tensor (onnx doesn't support negative indices in flatten)
return ret.flatten(start_dim=-4) return ret.flatten(start_dim=1)
def compute_shifts(cell: Tensor, pbc: Tensor, cutoff: float) -> Tensor: def compute_shifts(cell: Tensor, pbc: Tensor, cutoff: float) -> Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment