Unverified Commit 98bb1237 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Use modern type annotations for aev.py (#372)

* Use modern type annotations for aev.py

* commit
parent 7499c8d4
import torch
from torch import Tensor
import math
from typing import Tuple, Optional
def cutoff_cosine(distances, cutoff):
# type: (torch.Tensor, float) -> torch.Tensor
def cutoff_cosine(distances: Tensor, cutoff: float) -> Tensor:
# assuming all elements in distances are smaller than cutoff
return 0.5 * torch.cos(distances * (math.pi / cutoff)) + 0.5
def radial_terms(Rcr, EtaR, ShfR, distances):
# type: (float, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
def radial_terms(Rcr: float, EtaR: Tensor, ShfR: Tensor, distances: Tensor) -> Tensor:
"""Compute the radial subAEV terms of the center atom given neighbors
This correspond to equation (3) in the `ANI paper`_. This function just
......@@ -36,8 +35,8 @@ def radial_terms(Rcr, EtaR, ShfR, distances):
return ret.flatten(start_dim=-2)
def angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2):
# type: (float, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
def angular_terms(Rca: float, ShfZ: Tensor, EtaA: Tensor, Zeta: Tensor,
ShfA: Tensor, vectors1: Tensor, vectors2: Tensor) -> Tensor:
"""Compute the angular subAEV terms of the center atom given neighbor pairs.
This correspond to equation (4) in the `ANI paper`_. This function just
......@@ -72,8 +71,7 @@ def angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2):
return ret.flatten(start_dim=-4)
def compute_shifts(cell, pbc, cutoff):
# type: (torch.Tensor, torch.Tensor, float) -> torch.Tensor
def compute_shifts(cell: Tensor, pbc: Tensor, cutoff: float) -> Tensor:
"""Compute the shifts of unit cell along the given cell vectors to make it
large enough to contain all pairs of neighbor atoms with PBC under
consideration
......@@ -115,8 +113,8 @@ def compute_shifts(cell, pbc, cutoff):
])
def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff):
# type: (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
def neighbor_pairs(padding_mask: Tensor, coordinates: Tensor, cell: Tensor,
shifts: Tensor, cutoff: float) -> Tuple[Tensor, Tensor, Tensor]:
"""Compute pairs of atoms that are neighbors
Arguments:
......@@ -164,8 +162,7 @@ def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff):
return molecule_index + atom_index1, molecule_index + atom_index2, shifts
def triu_index(num_species):
# type: (int) -> torch.Tensor
def triu_index(num_species: int) -> Tensor:
species1, species2 = torch.triu_indices(num_species, num_species).unbind(0)
pair_index = torch.arange(species1.shape[0], dtype=torch.long)
ret = torch.zeros(num_species, num_species, dtype=torch.long)
......@@ -174,15 +171,13 @@ def triu_index(num_species):
return ret
def cumsum_from_zero(input_):
# type: (torch.Tensor) -> torch.Tensor
def cumsum_from_zero(input_: Tensor) -> Tensor:
cumsum = torch.cumsum(input_, dim=0)
cumsum = torch.cat([input_.new_zeros(1), cumsum[:-1]])
return cumsum
def triple_by_molecule(atom_index1, atom_index2):
# type: (torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
def triple_by_molecule(atom_index1: Tensor, atom_index2: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
"""Input: indices for pairs of atoms that are close to each other.
each pair only appear once, i.e. only one of the pairs (1, 2) and
(2, 1) exists.
......@@ -228,8 +223,10 @@ def triple_by_molecule(atom_index1, atom_index2):
return central_atom_index, local_index1 % n, local_index2 % n, sign1, sign2
def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes):
# type: (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[float, torch.Tensor, torch.Tensor, float, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], Tuple[int, int, int, int, int, int]) > torch.Tensor
def compute_aev(species: Tensor, coordinates: Tensor, cell: Tensor,
shifts: Tensor, triu_index: Tensor,
constants: Tuple[float, Tensor, Tensor, float, Tensor, Tensor, Tensor, Tensor],
sizes: Tuple[int, int, int, int, int, int]) -> Tensor:
Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA = constants
num_species, radial_sublength, radial_length, angular_sublength, angular_length, aev_length = sizes
num_molecules = species.shape[0]
......@@ -349,8 +346,8 @@ class AEVComputer(torch.nn.Module):
def constants(self):
return self.Rcr, self.EtaR, self.ShfR, self.Rca, self.ShfZ, self.EtaA, self.Zeta, self.ShfA
def forward(self, input_, cell=None, pbc=None):
# type: (Tuple[torch.Tensor, torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
def forward(self, input_: Tuple[Tensor, Tensor], cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
"""Compute AEVs
Arguments:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment