Unverified Commit a3720bc2 authored by Ignacio Pickering's avatar Ignacio Pickering Committed by GitHub
Browse files

Fix handwritten cosine similarity and add test for NaNs in AEVComputer (#561)

parent d7302cc3
...@@ -118,6 +118,13 @@ class TestAEV(_TestAEVBase): ...@@ -118,6 +118,13 @@ class TestAEV(_TestAEVBase):
_, aev = self.aev_computer((species, coordinates)) _, aev = self.aev_computer((species, coordinates))
self.assertAEVEqual(expected_radial, expected_angular, aev) self.assertAEVEqual(expected_radial, expected_angular, aev)
def testNoNan(self):
# AEV should not output NaN even when coordinates are superimposed
coordinates = torch.ones(1, 3, 3, dtype=torch.float)
species = torch.zeros(1, 3, dtype=torch.long)
_, aev = self.aev_computer((species, coordinates))
self.assertFalse(torch.isnan(aev).any())
def testPadding(self): def testPadding(self):
species_coordinates = [] species_coordinates = []
radial_angular = [] radial_angular = []
......
...@@ -76,8 +76,7 @@ def angular_terms(Rca: float, ShfZ: Tensor, EtaA: Tensor, Zeta: Tensor, ...@@ -76,8 +76,7 @@ def angular_terms(Rca: float, ShfZ: Tensor, EtaA: Tensor, Zeta: Tensor,
""" """
vectors12 = vectors12.view(2, -1, 3, 1, 1, 1, 1) vectors12 = vectors12.view(2, -1, 3, 1, 1, 1, 1)
distances12 = vectors12.norm(2, dim=-5) distances12 = vectors12.norm(2, dim=-5)
cos_angles = vectors12.prod(0).sum(1) / torch.clamp(distances12.prod(0), min=1e-10)
cos_angles = vectors12.prod(0).sum(1) / distances12.prod(0)
# 0.95 is multiplied to the cos values to prevent acos from returning NaN. # 0.95 is multiplied to the cos values to prevent acos from returning NaN.
angles = torch.acos(0.95 * cos_angles) angles = torch.acos(0.95 * cos_angles)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment