Unverified Commit ea51fadb authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Add mypy coverage (#530)

parent dec9c0eb
......@@ -28,4 +28,4 @@ jobs:
pip install --upgrade pip
pip install mypy
- name: Type checking with mypy
run: mypy --no-site-packages --ignore-missing-imports torchani/{aev,nn,utils,models,ase}.py
run: mypy --ignore-missing-imports .
......@@ -10,7 +10,7 @@ if BUILD_CUAEV:
sys.argv.remove('--cuaev')
if not BUILD_CUAEV:
log.warn("Will not install cuaev")
log.warn("Will not install cuaev") # type: ignore
with open("README.md", "r") as fh:
long_description = fh.read()
......
......@@ -102,9 +102,9 @@ def compute_shifts(cell: Tensor, pbc: Tensor, cutoff: float) -> Tensor:
inv_distances = reciprocal_cell.norm(2, -1)
num_repeats = torch.ceil(cutoff * inv_distances).to(torch.long)
num_repeats = torch.where(pbc, num_repeats, num_repeats.new_zeros(()))
r1 = torch.arange(1, num_repeats[0] + 1, device=cell.device)
r2 = torch.arange(1, num_repeats[1] + 1, device=cell.device)
r3 = torch.arange(1, num_repeats[2] + 1, device=cell.device)
r1 = torch.arange(1, num_repeats[0].item() + 1, device=cell.device)
r2 = torch.arange(1, num_repeats[1].item() + 1, device=cell.device)
r3 = torch.arange(1, num_repeats[2].item() + 1, device=cell.device)
o = torch.zeros(1, dtype=torch.long, device=cell.device)
return torch.cat([
torch.cartesian_prod(r1, r2, r3),
......@@ -348,6 +348,7 @@ class AEVComputer(torch.nn.Module):
angular_length: Final[int]
aev_length: Final[int]
sizes: Final[Tuple[int, int, int, int, int]]
triu_index: Tensor
def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species):
super().__init__()
......
......@@ -51,7 +51,7 @@ class ANIModel(torch.nn.ModuleDict):
def __init__(self, modules):
super(ANIModel, self).__init__(self.ensureOrderedDict(modules))
def forward(self, species_aev: Tuple[Tensor, Tensor],
def forward(self, species_aev: Tuple[Tensor, Tensor], # type: ignore
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> SpeciesEnergies:
species, aev = species_aev
......@@ -79,7 +79,7 @@ class Ensemble(torch.nn.ModuleList):
super().__init__(modules)
self.size = len(modules)
def forward(self, species_input: Tuple[Tensor, Tensor],
def forward(self, species_input: Tuple[Tensor, Tensor], # type: ignore
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> SpeciesEnergies:
sum_ = 0
......@@ -95,7 +95,7 @@ class Sequential(torch.nn.ModuleList):
def __init__(self, *modules):
super(Sequential, self).__init__(modules)
def forward(self, input_: Tuple[Tensor, Tensor],
def forward(self, input_: Tuple[Tensor, Tensor], # type: ignore
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None):
for module in self:
......@@ -121,6 +121,7 @@ class SpeciesConverter(torch.nn.Module):
sequence of all supported species, in order (it is recommended to order
according to atomic number).
"""
conv_tensor: Tensor
def __init__(self, species):
super().__init__()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment