Commit f825c99e authored by Farhad Ramezanghorbani's avatar Farhad Ramezanghorbani Committed by Gao, Xiang
Browse files

Remove unnecessary import (#296)

* Remove unnecessary import

* fix
parent 1455cb37
......@@ -2,20 +2,19 @@ from __future__ import division
import torch
from . import _six # noqa:F401
import math
from torch import Tensor
from typing import Tuple
# @torch.jit.script
def cutoff_cosine(distances, cutoff):
# type: (Tensor, float) -> Tensor
# type: (torch.Tensor, float) -> torch.Tensor
# assuming all elements in distances are smaller than cutoff
return 0.5 * torch.cos(distances * (math.pi / cutoff)) + 0.5
# @torch.jit.script
def radial_terms(Rcr, EtaR, ShfR, distances):
# type: (float, Tensor, Tensor, Tensor) -> Tensor
# type: (float, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
"""Compute the radial subAEV terms of the center atom given neighbors
This correspond to equation (3) in the `ANI paper`_. This function just
......@@ -43,7 +42,7 @@ def radial_terms(Rcr, EtaR, ShfR, distances):
# @torch.jit.script
def angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2):
# type: (float, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tensor
# type: (float, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
"""Compute the angular subAEV terms of the center atom given neighbor pairs.
This correspond to equation (4) in the `ANI paper`_. This function just
......@@ -96,7 +95,7 @@ def compute_shifts(cell, pbc, cutoff):
:class:`torch.Tensor`: long tensor of shifts. the center cell and
symmetric cells are not included.
"""
# type: (Tensor, Tensor, float) -> Tensor
# type: (torch.Tensor, torch.Tensor, float) -> torch.Tensor
reciprocal_cell = cell.inverse().t()
inv_distances = reciprocal_cell.norm(2, -1)
num_repeats = torch.ceil(cutoff * inv_distances).to(torch.long)
......@@ -136,7 +135,7 @@ def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff):
cutoff (float): the cutoff inside which atoms are considered pairs
shifts (:class:`torch.Tensor`): tensor of shape (?, 3) storing shifts
"""
# type: (Tensor, Tensor, Tensor, Tensor, float) -> Tuple[Tensor, Tensor, Tensor, Tensor]
# type: (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
coordinates = coordinates.detach()
cell = cell.detach()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment