Unverified Commit 81e6150c authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

__constants__ is deprecated by torch.jit (#378)

* __constants__ is deprecated

* commit
parent d32081e9
...@@ -2,6 +2,7 @@ import torch ...@@ -2,6 +2,7 @@ import torch
from torch import Tensor from torch import Tensor
import math import math
from typing import Tuple, Optional from typing import Tuple, Optional
from torch.jit import Final
def cutoff_cosine(distances: Tensor, cutoff: float) -> Tensor: def cutoff_cosine(distances: Tensor, cutoff: float) -> Tensor:
...@@ -226,9 +227,9 @@ def triple_by_molecule(atom_index1: Tensor, atom_index2: Tensor) -> Tuple[Tensor ...@@ -226,9 +227,9 @@ def triple_by_molecule(atom_index1: Tensor, atom_index2: Tensor) -> Tuple[Tensor
def compute_aev(species: Tensor, coordinates: Tensor, cell: Tensor, def compute_aev(species: Tensor, coordinates: Tensor, cell: Tensor,
shifts: Tensor, triu_index: Tensor, shifts: Tensor, triu_index: Tensor,
constants: Tuple[float, Tensor, Tensor, float, Tensor, Tensor, Tensor, Tensor], constants: Tuple[float, Tensor, Tensor, float, Tensor, Tensor, Tensor, Tensor],
sizes: Tuple[int, int, int, int, int, int]) -> Tensor: sizes: Tuple[int, int, int, int, int]) -> Tensor:
Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA = constants Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA = constants
num_species, radial_sublength, radial_length, angular_sublength, angular_length, aev_length = sizes num_species, radial_sublength, radial_length, angular_sublength, angular_length = sizes
num_molecules = species.shape[0] num_molecules = species.shape[0]
num_atoms = species.shape[1] num_atoms = species.shape[1]
num_species_pairs = angular_length // angular_sublength num_species_pairs = angular_length // angular_sublength
...@@ -300,15 +301,24 @@ class AEVComputer(torch.nn.Module): ...@@ -300,15 +301,24 @@ class AEVComputer(torch.nn.Module):
.. _ANI paper: .. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
""" """
__constants__ = ['Rcr', 'Rca', 'num_species', 'radial_sublength', Rcr: Final[float]
'radial_length', 'angular_sublength', 'angular_length', Rca: Final[float]
'aev_length'] num_species: Final[int]
radial_sublength: Final[int]
radial_length: Final[int]
angular_sublength: Final[int]
angular_length: Final[int]
aev_length: Final[int]
sizes: Final[Tuple[int, int, int, int, int]]
def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species): def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species):
super(AEVComputer, self).__init__() super(AEVComputer, self).__init__()
self.Rcr = Rcr self.Rcr = Rcr
self.Rca = Rca self.Rca = Rca
assert Rca <= Rcr, "Current implementation of AEVComputer assumes Rca <= Rcr" assert Rca <= Rcr, "Current implementation of AEVComputer assumes Rca <= Rcr"
self.num_species = num_species
# convert constant tensors to a ready-to-broadcast shape # convert constant tensors to a ready-to-broadcast shape
# shape convension (..., EtaR, ShfR) # shape convension (..., EtaR, ShfR)
self.register_buffer('EtaR', EtaR.view(-1, 1)) self.register_buffer('EtaR', EtaR.view(-1, 1))
...@@ -319,7 +329,6 @@ class AEVComputer(torch.nn.Module): ...@@ -319,7 +329,6 @@ class AEVComputer(torch.nn.Module):
self.register_buffer('ShfA', ShfA.view(1, 1, -1, 1)) self.register_buffer('ShfA', ShfA.view(1, 1, -1, 1))
self.register_buffer('ShfZ', ShfZ.view(1, 1, 1, -1)) self.register_buffer('ShfZ', ShfZ.view(1, 1, 1, -1))
self.num_species = num_species
# The length of radial subaev of a single species # The length of radial subaev of a single species
self.radial_sublength = self.EtaR.numel() * self.ShfR.numel() self.radial_sublength = self.EtaR.numel() * self.ShfR.numel()
# The length of full radial aev # The length of full radial aev
...@@ -330,7 +339,7 @@ class AEVComputer(torch.nn.Module): ...@@ -330,7 +339,7 @@ class AEVComputer(torch.nn.Module):
self.angular_length = (self.num_species * (self.num_species + 1)) // 2 * self.angular_sublength self.angular_length = (self.num_species * (self.num_species + 1)) // 2 * self.angular_sublength
# The length of full aev # The length of full aev
self.aev_length = self.radial_length + self.angular_length self.aev_length = self.radial_length + self.angular_length
self.sizes = self.num_species, self.radial_sublength, self.radial_length, self.angular_sublength, self.angular_length, self.aev_length self.sizes = self.num_species, self.radial_sublength, self.radial_length, self.angular_sublength, self.angular_length
self.register_buffer('triu_index', triu_index(num_species).to(device=self.EtaR.device)) self.register_buffer('triu_index', triu_index(num_species).to(device=self.EtaR.device))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment