Unverified Commit 31bf913d authored by Farhad Ramezanghorbani's avatar Farhad Ramezanghorbani Committed by GitHub
Browse files

Raise error for unknown species (#512)

* raise error for unknown species

* make it jittable

* fix dimension bug
parent 5fbd9edd
...@@ -439,6 +439,7 @@ class AEVComputer(torch.nn.Module): ...@@ -439,6 +439,7 @@ class AEVComputer(torch.nn.Module):
unchanged, and AEVs is a tensor of shape ``(N, A, self.aev_length())`` unchanged, and AEVs is a tensor of shape ``(N, A, self.aev_length())``
""" """
species, coordinates = input_ species, coordinates = input_
assert species.shape == coordinates.shape[:-1]
if cell is None and pbc is None: if cell is None and pbc is None:
aev = compute_aev(species, coordinates, self.triu_index, self.constants(), self.sizes, None) aev = compute_aev(species, coordinates, self.triu_index, self.constants(), self.sizes, None)
......
...@@ -139,6 +139,11 @@ class BuiltinModel(torch.nn.Module): ...@@ -139,6 +139,11 @@ class BuiltinModel(torch.nn.Module):
""" """
if self.periodic_table_index: if self.periodic_table_index:
species_coordinates = self.species_converter(species_coordinates) species_coordinates = self.species_converter(species_coordinates)
# check if unknown species are included
if species_coordinates[0].ge(self.aev_computer.num_species).any():
raise ValueError(f'Unknown species found in {species_coordinates[0]}')
species_aevs = self.aev_computer(species_coordinates, cell=cell, pbc=pbc) species_aevs = self.aev_computer(species_coordinates, cell=cell, pbc=pbc)
species_energies = self.neural_networks(species_aevs) species_energies = self.neural_networks(species_aevs)
return self.energy_shifter(species_energies) return self.energy_shifter(species_energies)
......
...@@ -55,6 +55,8 @@ class ANIModel(torch.nn.ModuleDict): ...@@ -55,6 +55,8 @@ class ANIModel(torch.nn.ModuleDict):
cell: Optional[Tensor] = None, cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> SpeciesEnergies: pbc: Optional[Tensor] = None) -> SpeciesEnergies:
species, aev = species_aev species, aev = species_aev
assert species.shape == aev.shape[:-1]
species_ = species.flatten() species_ = species.flatten()
aev = aev.flatten(0, 1) aev = aev.flatten(0, 1)
...@@ -133,4 +135,10 @@ class SpeciesConverter(torch.nn.Module): ...@@ -133,4 +135,10 @@ class SpeciesConverter(torch.nn.Module):
pbc: Optional[Tensor] = None): pbc: Optional[Tensor] = None):
"""Convert species from periodic table element index to 0, 1, 2, 3, ... indexing""" """Convert species from periodic table element index to 0, 1, 2, 3, ... indexing"""
species, coordinates = input_ species, coordinates = input_
return SpeciesCoordinates(self.conv_tensor[species].to(species.device), coordinates) converted_species = self.conv_tensor[species]
# check if unknown species are included
if converted_species[species.ne(-1)].lt(0).any():
raise ValueError(f'Unknown species found in {species}')
return SpeciesCoordinates(converted_species.to(species.device), coordinates)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment