Commit f825c99e authored by Farhad Ramezanghorbani's avatar Farhad Ramezanghorbani Committed by Gao, Xiang
Browse files

Remove unnecessary import (#296)

* Remove unnecessary import

* fix
parent 1455cb37
...@@ -2,20 +2,19 @@ from __future__ import division ...@@ -2,20 +2,19 @@ from __future__ import division
import torch import torch
from . import _six # noqa:F401 from . import _six # noqa:F401
import math import math
from torch import Tensor
from typing import Tuple from typing import Tuple
# @torch.jit.script # @torch.jit.script
def cutoff_cosine(distances, cutoff): def cutoff_cosine(distances, cutoff):
# type: (Tensor, float) -> Tensor # type: (torch.Tensor, float) -> torch.Tensor
# assuming all elements in distances are smaller than cutoff # assuming all elements in distances are smaller than cutoff
return 0.5 * torch.cos(distances * (math.pi / cutoff)) + 0.5 return 0.5 * torch.cos(distances * (math.pi / cutoff)) + 0.5
# @torch.jit.script # @torch.jit.script
def radial_terms(Rcr, EtaR, ShfR, distances): def radial_terms(Rcr, EtaR, ShfR, distances):
# type: (float, Tensor, Tensor, Tensor) -> Tensor # type: (float, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
"""Compute the radial subAEV terms of the center atom given neighbors """Compute the radial subAEV terms of the center atom given neighbors
This correspond to equation (3) in the `ANI paper`_. This function just This correspond to equation (3) in the `ANI paper`_. This function just
...@@ -43,7 +42,7 @@ def radial_terms(Rcr, EtaR, ShfR, distances): ...@@ -43,7 +42,7 @@ def radial_terms(Rcr, EtaR, ShfR, distances):
# @torch.jit.script # @torch.jit.script
def angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2): def angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2):
# type: (float, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tensor # type: (float, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
"""Compute the angular subAEV terms of the center atom given neighbor pairs. """Compute the angular subAEV terms of the center atom given neighbor pairs.
This correspond to equation (4) in the `ANI paper`_. This function just This correspond to equation (4) in the `ANI paper`_. This function just
...@@ -96,7 +95,7 @@ def compute_shifts(cell, pbc, cutoff): ...@@ -96,7 +95,7 @@ def compute_shifts(cell, pbc, cutoff):
:class:`torch.Tensor`: long tensor of shifts. the center cell and :class:`torch.Tensor`: long tensor of shifts. the center cell and
symmetric cells are not included. symmetric cells are not included.
""" """
# type: (Tensor, Tensor, float) -> Tensor # type: (torch.Tensor, torch.Tensor, float) -> torch.Tensor
reciprocal_cell = cell.inverse().t() reciprocal_cell = cell.inverse().t()
inv_distances = reciprocal_cell.norm(2, -1) inv_distances = reciprocal_cell.norm(2, -1)
num_repeats = torch.ceil(cutoff * inv_distances).to(torch.long) num_repeats = torch.ceil(cutoff * inv_distances).to(torch.long)
...@@ -136,7 +135,7 @@ def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff): ...@@ -136,7 +135,7 @@ def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff):
cutoff (float): the cutoff inside which atoms are considered pairs cutoff (float): the cutoff inside which atoms are considered pairs
shifts (:class:`torch.Tensor`): tensor of shape (?, 3) storing shifts shifts (:class:`torch.Tensor`): tensor of shape (?, 3) storing shifts
""" """
# type: (Tensor, Tensor, Tensor, Tensor, float) -> Tuple[Tensor, Tensor, Tensor, Tensor] # type: (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
coordinates = coordinates.detach() coordinates = coordinates.detach()
cell = cell.detach() cell = cell.detach()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment