Unverified Commit 98bb1237 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Use modern type annotations for aev.py (#372)

* Use modern type annotations for aev.py

* commit
parent 7499c8d4
import torch import torch
from torch import Tensor
import math import math
from typing import Tuple, Optional from typing import Tuple, Optional
def cutoff_cosine(distances, cutoff): def cutoff_cosine(distances: Tensor, cutoff: float) -> Tensor:
# type: (torch.Tensor, float) -> torch.Tensor
# assuming all elements in distances are smaller than cutoff # assuming all elements in distances are smaller than cutoff
return 0.5 * torch.cos(distances * (math.pi / cutoff)) + 0.5 return 0.5 * torch.cos(distances * (math.pi / cutoff)) + 0.5
def radial_terms(Rcr, EtaR, ShfR, distances): def radial_terms(Rcr: float, EtaR: Tensor, ShfR: Tensor, distances: Tensor) -> Tensor:
# type: (float, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
"""Compute the radial subAEV terms of the center atom given neighbors """Compute the radial subAEV terms of the center atom given neighbors
This correspond to equation (3) in the `ANI paper`_. This function just This correspond to equation (3) in the `ANI paper`_. This function just
...@@ -36,8 +35,8 @@ def radial_terms(Rcr, EtaR, ShfR, distances): ...@@ -36,8 +35,8 @@ def radial_terms(Rcr, EtaR, ShfR, distances):
return ret.flatten(start_dim=-2) return ret.flatten(start_dim=-2)
def angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2): def angular_terms(Rca: float, ShfZ: Tensor, EtaA: Tensor, Zeta: Tensor,
# type: (float, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor ShfA: Tensor, vectors1: Tensor, vectors2: Tensor) -> Tensor:
"""Compute the angular subAEV terms of the center atom given neighbor pairs. """Compute the angular subAEV terms of the center atom given neighbor pairs.
This correspond to equation (4) in the `ANI paper`_. This function just This correspond to equation (4) in the `ANI paper`_. This function just
...@@ -72,8 +71,7 @@ def angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2): ...@@ -72,8 +71,7 @@ def angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2):
return ret.flatten(start_dim=-4) return ret.flatten(start_dim=-4)
def compute_shifts(cell, pbc, cutoff): def compute_shifts(cell: Tensor, pbc: Tensor, cutoff: float) -> Tensor:
# type: (torch.Tensor, torch.Tensor, float) -> torch.Tensor
"""Compute the shifts of unit cell along the given cell vectors to make it """Compute the shifts of unit cell along the given cell vectors to make it
large enough to contain all pairs of neighbor atoms with PBC under large enough to contain all pairs of neighbor atoms with PBC under
consideration consideration
...@@ -115,8 +113,8 @@ def compute_shifts(cell, pbc, cutoff): ...@@ -115,8 +113,8 @@ def compute_shifts(cell, pbc, cutoff):
]) ])
def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff): def neighbor_pairs(padding_mask: Tensor, coordinates: Tensor, cell: Tensor,
# type: (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor] shifts: Tensor, cutoff: float) -> Tuple[Tensor, Tensor, Tensor]:
"""Compute pairs of atoms that are neighbors """Compute pairs of atoms that are neighbors
Arguments: Arguments:
...@@ -164,8 +162,7 @@ def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff): ...@@ -164,8 +162,7 @@ def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff):
return molecule_index + atom_index1, molecule_index + atom_index2, shifts return molecule_index + atom_index1, molecule_index + atom_index2, shifts
def triu_index(num_species): def triu_index(num_species: int) -> Tensor:
# type: (int) -> torch.Tensor
species1, species2 = torch.triu_indices(num_species, num_species).unbind(0) species1, species2 = torch.triu_indices(num_species, num_species).unbind(0)
pair_index = torch.arange(species1.shape[0], dtype=torch.long) pair_index = torch.arange(species1.shape[0], dtype=torch.long)
ret = torch.zeros(num_species, num_species, dtype=torch.long) ret = torch.zeros(num_species, num_species, dtype=torch.long)
...@@ -174,15 +171,13 @@ def triu_index(num_species): ...@@ -174,15 +171,13 @@ def triu_index(num_species):
return ret return ret
def cumsum_from_zero(input_): def cumsum_from_zero(input_: Tensor) -> Tensor:
# type: (torch.Tensor) -> torch.Tensor
cumsum = torch.cumsum(input_, dim=0) cumsum = torch.cumsum(input_, dim=0)
cumsum = torch.cat([input_.new_zeros(1), cumsum[:-1]]) cumsum = torch.cat([input_.new_zeros(1), cumsum[:-1]])
return cumsum return cumsum
def triple_by_molecule(atom_index1, atom_index2): def triple_by_molecule(atom_index1: Tensor, atom_index2: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
# type: (torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
"""Input: indices for pairs of atoms that are close to each other. """Input: indices for pairs of atoms that are close to each other.
each pair only appear once, i.e. only one of the pairs (1, 2) and each pair only appear once, i.e. only one of the pairs (1, 2) and
(2, 1) exists. (2, 1) exists.
...@@ -228,8 +223,10 @@ def triple_by_molecule(atom_index1, atom_index2): ...@@ -228,8 +223,10 @@ def triple_by_molecule(atom_index1, atom_index2):
return central_atom_index, local_index1 % n, local_index2 % n, sign1, sign2 return central_atom_index, local_index1 % n, local_index2 % n, sign1, sign2
def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes): def compute_aev(species: Tensor, coordinates: Tensor, cell: Tensor,
# type: (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[float, torch.Tensor, torch.Tensor, float, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], Tuple[int, int, int, int, int, int]) > torch.Tensor shifts: Tensor, triu_index: Tensor,
constants: Tuple[float, Tensor, Tensor, float, Tensor, Tensor, Tensor, Tensor],
sizes: Tuple[int, int, int, int, int, int]) -> Tensor:
Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA = constants Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA = constants
num_species, radial_sublength, radial_length, angular_sublength, angular_length, aev_length = sizes num_species, radial_sublength, radial_length, angular_sublength, angular_length, aev_length = sizes
num_molecules = species.shape[0] num_molecules = species.shape[0]
...@@ -349,8 +346,8 @@ class AEVComputer(torch.nn.Module): ...@@ -349,8 +346,8 @@ class AEVComputer(torch.nn.Module):
def constants(self): def constants(self):
return self.Rcr, self.EtaR, self.ShfR, self.Rca, self.ShfZ, self.EtaA, self.Zeta, self.ShfA return self.Rcr, self.EtaR, self.ShfR, self.Rca, self.ShfZ, self.EtaA, self.Zeta, self.ShfA
def forward(self, input_, cell=None, pbc=None): def forward(self, input_: Tuple[Tensor, Tensor], cell: Optional[Tensor] = None,
# type: (Tuple[torch.Tensor, torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor] pbc: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
"""Compute AEVs """Compute AEVs
Arguments: Arguments:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment