Unverified commit ea51fadb authored by Gao, Xiang, committed by GitHub

Add mypy coverage (#530)

parent dec9c0eb
@@ -28,4 +28,4 @@ jobs:
         pip install --upgrade pip
         pip install mypy
     - name: Type checking with mypy
-      run: mypy --no-site-packages --ignore-missing-imports torchani/{aev,nn,utils,models,ase}.py
+      run: mypy --ignore-missing-imports .
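With the whole tree now checked, --ignore-missing-imports keeps mypy from failing on dependencies that ship no type stubs. A standalone sketch of the error class it suppresses (the ase import is only an illustration; none of this code is part of the commit):

import ase  # without --ignore-missing-imports mypy reports
            # "Cannot find implementation or library stub for module named 'ase'"

def cell_volume(atoms: ase.Atoms) -> float:
    # the function body is still fully type checked; only the missing-stub
    # diagnostic for the untyped dependency is silenced
    return float(atoms.get_volume())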
@@ -10,7 +10,7 @@ if BUILD_CUAEV:
     sys.argv.remove('--cuaev')
 if not BUILD_CUAEV:
-    log.warn("Will not install cuaev")
+    log.warn("Will not install cuaev")  # type: ignore
 with open("README.md", "r") as fh:
     long_description = fh.read()
...
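The bare # type: ignore above silences whatever mypy reports for the log.warn call. A narrower, hypothetical variant scopes the suppression to a single error code (assumption: the installed mypy supports error codes, and attr-defined is the code actually raised here; it may differ by mypy/typeshed version):

from distutils import log

BUILD_CUAEV = False  # stand-in for the flag computed earlier in setup.py

if not BUILD_CUAEV:
    # error-code-qualified ignore; behaves like the bare ignore in the commit
    # but will not hide unrelated errors introduced on this line later
    log.warn("Will not install cuaev")  # type: ignore[attr-defined]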
@@ -102,9 +102,9 @@ def compute_shifts(cell: Tensor, pbc: Tensor, cutoff: float) -> Tensor:
     inv_distances = reciprocal_cell.norm(2, -1)
     num_repeats = torch.ceil(cutoff * inv_distances).to(torch.long)
     num_repeats = torch.where(pbc, num_repeats, num_repeats.new_zeros(()))
-    r1 = torch.arange(1, num_repeats[0] + 1, device=cell.device)
-    r2 = torch.arange(1, num_repeats[1] + 1, device=cell.device)
-    r3 = torch.arange(1, num_repeats[2] + 1, device=cell.device)
+    r1 = torch.arange(1, num_repeats[0].item() + 1, device=cell.device)
+    r2 = torch.arange(1, num_repeats[1].item() + 1, device=cell.device)
+    r3 = torch.arange(1, num_repeats[2].item() + 1, device=cell.device)
     o = torch.zeros(1, dtype=torch.long, device=cell.device)
     return torch.cat([
         torch.cartesian_prod(r1, r2, r3),
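The .item() calls address a typing mismatch rather than a behaviour change: num_repeats[0] is a 0-d Tensor, while the torch.arange stubs expect a Python number for the end argument; .item() returns that number and the generated range is unchanged. A minimal standalone illustration (not from the repository):

import torch

num_repeats = torch.tensor([2, 3, 1])
# num_repeats[0] + 1 is still a Tensor, which mypy rejects against the
# torch.arange stub; .item() yields a plain int with an identical runtime result
r1 = torch.arange(1, num_repeats[0].item() + 1)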
@@ -348,6 +348,7 @@ class AEVComputer(torch.nn.Module):
     angular_length: Final[int]
     aev_length: Final[int]
     sizes: Final[Tuple[int, int, int, int, int]]
+    triu_index: Tensor

     def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species):
         super().__init__()
...
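triu_index is assigned in __init__ (presumably via register_buffer), so without a class-level annotation mypy only sees it through nn.Module.__getattr__ and cannot treat it as a Tensor; the added annotation pins the type for both mypy and TorchScript. The conv_tensor annotation on SpeciesConverter further down follows the same pattern. A hypothetical minimal example of the idiom (the Toy class and its buffer are illustrative only):

import torch
from torch import Tensor

class Toy(torch.nn.Module):
    # class-level annotation: declares the registered buffer as a Tensor,
    # instead of the loose type nn.Module.__getattr__ would otherwise imply
    triu_index: Tensor

    def __init__(self, num_species: int):
        super().__init__()
        i, j = torch.triu_indices(num_species, num_species).unbind(0)
        self.register_buffer('triu_index', i * num_species + j)

    def forward(self, x: Tensor) -> Tensor:
        # with the annotation, indexing with self.triu_index type checks cleanly
        return x[self.triu_index]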
@@ -51,7 +51,7 @@ class ANIModel(torch.nn.ModuleDict):
     def __init__(self, modules):
         super(ANIModel, self).__init__(self.ensureOrderedDict(modules))

-    def forward(self, species_aev: Tuple[Tensor, Tensor],
+    def forward(self, species_aev: Tuple[Tensor, Tensor],  # type: ignore
                 cell: Optional[Tensor] = None,
                 pbc: Optional[Tensor] = None) -> SpeciesEnergies:
         species, aev = species_aev
@@ -79,7 +79,7 @@ class Ensemble(torch.nn.ModuleList):
         super().__init__(modules)
         self.size = len(modules)

-    def forward(self, species_input: Tuple[Tensor, Tensor],
+    def forward(self, species_input: Tuple[Tensor, Tensor],  # type: ignore
                 cell: Optional[Tensor] = None,
                 pbc: Optional[Tensor] = None) -> SpeciesEnergies:
         sum_ = 0
@@ -95,7 +95,7 @@ class Sequential(torch.nn.ModuleList):
     def __init__(self, *modules):
         super(Sequential, self).__init__(modules)

-    def forward(self, input_: Tuple[Tensor, Tensor],
+    def forward(self, input_: Tuple[Tensor, Tensor],  # type: ignore
                 cell: Optional[Tensor] = None,
                 pbc: Optional[Tensor] = None):
         for module in self:
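The # type: ignore comments on these forward overrides exist because the torch stubs declare nn.Module.forward roughly as (*input: Any) -> Any, so a narrower signature with specific positional and keyword parameters is reported as an incompatible override; the rest of each method is still checked. A standalone sketch of the same pattern (the Demo class is illustrative, not part of torchani):

from typing import Optional, Tuple
import torch
from torch import Tensor

class Demo(torch.nn.Module):
    # the narrower signature trips mypy's override check against the base
    # class stub; the trailing ignore silences only this line
    def forward(self, species_x: Tuple[Tensor, Tensor],  # type: ignore
                cell: Optional[Tensor] = None,
                pbc: Optional[Tensor] = None) -> Tensor:
        _, x = species_x
        return x.sum()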
@@ -121,6 +121,7 @@ class SpeciesConverter(torch.nn.Module):
     sequence of all supported species, in order (it is recommended to order
     according to atomic number).
     """
+    conv_tensor: Tensor

     def __init__(self, species):
         super().__init__()
...