Unverified Commit c4f5d3b4 authored by Jinze Xue's avatar Jinze Xue Committed by GitHub
Browse files

CUAEV refactor to initialize aev parameter only once (#575)



* init

* separate to h cpp cu files

* works

* fix

* clean

* update

* fix

* only keep one copy of Result

* const result&

* move declaration to header

* fix jit

* clean

* fix

* fix jit

* fix ci test

* clean

* revert to autograd in charge of it's result

* update

* fix memory issue of result.aev_t

* Update torchani/cuaev/aev.h
Co-authored-by: default avatarGao, Xiang <qasdfgtyuiop@gmail.com>

* Update torchani/cuaev/aev.h
Co-authored-by: default avatarGao, Xiang <qasdfgtyuiop@gmail.com>

* Update torchani/cuaev/aev.h
Co-authored-by: default avatarGao, Xiang <qasdfgtyuiop@gmail.com>

* save

* Update torchani/cuaev/aev.h
Co-authored-by: default avatarGao, Xiang <qasdfgtyuiop@gmail.com>

* update

* Update torchani/cuaev/cuaev.cpp
Co-authored-by: default avatarGao, Xiang <qasdfgtyuiop@gmail.com>

* Update torchani/cuaev/cuaev.cpp
Co-authored-by: default avatarGao, Xiang <qasdfgtyuiop@gmail.com>

* Update torchani/cuaev/aev.h
Co-authored-by: default avatarGao, Xiang <qasdfgtyuiop@gmail.com>

* remove aev_t after forward

* remove release function

* fix

* fix jit

* update
Co-authored-by: default avatarGao, Xiang <qasdfgtyuiop@gmail.com>
parent 191cf2d1
import os import os
import glob
import subprocess import subprocess
from setuptools import setup, find_packages from setuptools import setup, find_packages
from distutils import log from distutils import log
...@@ -91,11 +90,12 @@ def cuda_extension(build_all=False): ...@@ -91,11 +90,12 @@ def cuda_extension(build_all=False):
nvcc_args.append("-gencode=arch=compute_86,code=sm_86") nvcc_args.append("-gencode=arch=compute_86,code=sm_86")
print("nvcc_args: ", nvcc_args) print("nvcc_args: ", nvcc_args)
print('-' * 75) print('-' * 75)
include_dirs = [*maybe_download_cub(), os.path.abspath("torchani/cuaev/")]
return CUDAExtension( return CUDAExtension(
name='torchani.cuaev', name='torchani.cuaev',
pkg='torchani.cuaev', pkg='torchani.cuaev',
sources=glob.glob('torchani/cuaev/*.cu'), sources=["torchani/cuaev/cuaev.cpp", "torchani/cuaev/aev.cu"],
include_dirs=maybe_download_cub(), include_dirs=include_dirs,
extra_compile_args={'cxx': ['-std=c++14'], 'nvcc': nvcc_args}) extra_compile_args={'cxx': ['-std=c++14'], 'nvcc': nvcc_args})
......
...@@ -18,9 +18,10 @@ class TestCUAEVNoGPU(TestCase): ...@@ -18,9 +18,10 @@ class TestCUAEVNoGPU(TestCase):
def testSimple(self): def testSimple(self):
def f(coordinates, species, Rcr: float, Rca: float, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species: int): def f(coordinates, species, Rcr: float, Rca: float, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species: int):
return torch.ops.cuaev.cuComputeAEV(coordinates, species, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species) cuaev_computer = torch.classes.cuaev.CuaevComputer(Rcr, Rca, EtaR.flatten(), ShfR.flatten(), EtaA.flatten(), Zeta.flatten(), ShfA.flatten(), ShfZ.flatten(), num_species)
return torch.ops.cuaev.run(coordinates, species, cuaev_computer)
s = torch.jit.script(f) s = torch.jit.script(f)
self.assertIn("cuaev::cuComputeAEV", str(s.graph)) self.assertIn("cuaev::run", str(s.graph))
def testAEVComputer(self): def testAEVComputer(self):
path = os.path.dirname(os.path.realpath(__file__)) path = os.path.dirname(os.path.realpath(__file__))
...@@ -31,7 +32,7 @@ class TestCUAEVNoGPU(TestCase): ...@@ -31,7 +32,7 @@ class TestCUAEVNoGPU(TestCase):
# Computation of AEV using cuaev when there is no atoms does not require CUDA, and can be run without GPU # Computation of AEV using cuaev when there is no atoms does not require CUDA, and can be run without GPU
species = make_tensor((8, 0), 'cpu', torch.int64, low=-1, high=4) species = make_tensor((8, 0), 'cpu', torch.int64, low=-1, high=4)
coordinates = make_tensor((8, 0, 3), 'cpu', torch.float32, low=-5, high=5) coordinates = make_tensor((8, 0, 3), 'cpu', torch.float32, low=-5, high=5)
self.assertIn("cuaev::cuComputeAEV", str(s.graph_for((species, coordinates)))) self.assertIn("cuaev::run", str(s.graph_for((species, coordinates))))
@skipIfNoGPU @skipIfNoGPU
......
...@@ -323,18 +323,12 @@ def compute_aev(species: Tensor, coordinates: Tensor, triu_index: Tensor, ...@@ -323,18 +323,12 @@ def compute_aev(species: Tensor, coordinates: Tensor, triu_index: Tensor,
return torch.cat([radial_aev, angular_aev], dim=-1) return torch.cat([radial_aev, angular_aev], dim=-1)
def compute_cuaev(species: Tensor, coordinates: Tensor, triu_index: Tensor, def jit_unused_if_no_cuaev(condition=has_cuaev):
constants: Tuple[float, Tensor, Tensor, float, Tensor, Tensor, Tensor, Tensor], def decorator(func):
num_species: int, cell_shifts: Optional[Tuple[Tensor, Tensor]]) -> Tensor: if not condition:
Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA = constants return torch.jit.unused(func)
return func
assert cell_shifts is None, "Current implementation of cuaev does not support pbc." return decorator
species_int = species.to(torch.int32)
return torch.ops.cuaev.cuComputeAEV(coordinates, species_int, Rcr, Rca, EtaR.flatten(), ShfR.flatten(), EtaA.flatten(), Zeta.flatten(), ShfA.flatten(), ShfZ.flatten(), num_species)
if not has_cuaev:
compute_cuaev = torch.jit.unused(compute_cuaev)
class AEVComputer(torch.nn.Module): class AEVComputer(torch.nn.Module):
...@@ -421,6 +415,24 @@ class AEVComputer(torch.nn.Module): ...@@ -421,6 +415,24 @@ class AEVComputer(torch.nn.Module):
self.register_buffer('default_cell', default_cell) self.register_buffer('default_cell', default_cell)
self.register_buffer('default_shifts', default_shifts) self.register_buffer('default_shifts', default_shifts)
# Should create only when use_cuda_extension is True.
# However jit needs to know cuaev_computer's Type even when use_cuda_extension is False, because it is enabled when cuaev is available
if has_cuaev:
self.init_cuaev_computer()
# When has_cuaev is true, and use_cuda_extension is false, and user enable use_cuda_extension afterwards,
# then another init_cuaev_computer will be needed
self.cuaev_enabled = True if self.use_cuda_extension else False
@jit_unused_if_no_cuaev()
def init_cuaev_computer(self):
self.cuaev_computer = torch.classes.cuaev.CuaevComputer(self.Rcr, self.Rca, self.EtaR.flatten(), self.ShfR.flatten(), self.EtaA.flatten(), self.Zeta.flatten(), self.ShfA.flatten(), self.ShfZ.flatten(), self.num_species)
@jit_unused_if_no_cuaev()
def compute_cuaev(self, species, coordinates):
species_int = species.to(torch.int32)
aev = torch.ops.cuaev.run(coordinates, species_int, self.cuaev_computer)
return aev
@classmethod @classmethod
def cover_linearly(cls, radial_cutoff: float, angular_cutoff: float, def cover_linearly(cls, radial_cutoff: float, angular_cutoff: float,
radial_eta: float, angular_eta: float, radial_eta: float, angular_eta: float,
...@@ -505,8 +517,11 @@ class AEVComputer(torch.nn.Module): ...@@ -505,8 +517,11 @@ class AEVComputer(torch.nn.Module):
assert coordinates.shape[-1] == 3 assert coordinates.shape[-1] == 3
if self.use_cuda_extension: if self.use_cuda_extension:
assert (cell is None and pbc is None), "cuaev does not support PBC" assert (cell is None and pbc is None), "cuaev currently does not support PBC"
aev = compute_cuaev(species, coordinates, self.triu_index, self.constants(), self.num_species, None) # if use_cuda_extension is enabled after initialization
if not self.cuaev_enabled:
self.init_cuaev_computer()
aev = self.compute_cuaev(species, coordinates)
return SpeciesAEV(species, aev) return SpeciesAEV(species, aev)
if cell is None and pbc is None: if cell is None and pbc is None:
......
#include <aev.h>
#include <thrust/equal.h> #include <thrust/equal.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <cub/cub.cuh> #include <cub/cub.cuh>
...@@ -10,139 +11,6 @@ ...@@ -10,139 +11,6 @@
#define PI 3.141592653589793 #define PI 3.141592653589793
using torch::Tensor; using torch::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::tensor_list;
// [Computation graph for forward, backward, and double backward]
//
// backward
// force = (dE / daev) * (daev / dcoord) = g * (daev / dcoord)
//
// double backward (to do force training, the term needed is)
// dloss / dg = (dloss / dforce) * (dforce / dg) = (dloss / dforce) * (daev / dcoord)
//
//
// [Forward]
// out ^
// | ^
// ... ^
// | ^
// e n e r g y ^
// | \ ^
// aev \ ^
// / | \ ^
// radial angular params ^
// / / | ^
// dist---^ / ^
// \ / ^
// coord ^
//
// Functional relationship:
// coord <-- input
// dist(coord)
// radial(dist)
// angular(dist, coord)
// aev = concatenate(radial, angular)
// energy(aev, params)
// out(energy, ....) <-- output
//
//
// [Backward]
// dout v
// | v
// ... v
// | v
// aev params denergy aev params v
// \ | / \ | / v
// d a e v dparams v
// / \____ v
// dist dradial \ v
// \ / \ v
// ddist dist coord dangular dist coord v
// \ / / \ | / v
// \_/____/ \___|___/ v
// | __________________/ v
// | / v
// dcoord v
// | v
// ... v
// | v
// out2 v
//
// Functional relationship:
// dout <-- input
// denergy(dout)
// dparams(denergy, aev, params) <-- output
// daev(denergy, aev, params)
// dradial = slice(daev)
// dangular = slice(daev)
// ddist = radial_backward(dradial, dist) + angular_backward_dist(dangular, ...)
// = radial_backward(dradial, dist) + 0 (all contributions route to dcoord)
// = radial_backward(dradial, dist)
// dcoord = dist_backward(ddist, coord, dist) + angular_backward_coord(dangular, coord, dist)
// out2(dcoord, ...) <-- output
//
//
// [Double backward w.r.t params (i.e. force training)]
// Note: only a very limited subset of double backward is implemented
// currently it can only do force training, there is no hessian support
// not implemented terms are marked by $s
// $$$ [dparams] $$$$ ^
// \_ | __/ ^
// [ddaev] ^
// / \_____ ^
// $$$$ [ddradial] \ ^
// \ / \ ^
// [dddist] $$$$ $$$$ [ddangular] $$$$ $$$$ ^
// \ / / \ | / ^
// \_/____/ \_____|___/ ^
// | _____________________/ ^
// | / ^
// [ddcoord] ^
// | ^
// ... ^
// | ^
// [dout2] ^
//
// Functional relationship:
// dout2 <-- input
// ddcoord(dout2, ...)
// dddist = dist_doublebackward(ddcoord, coord, dist)
// ddradial = radial_doublebackward(dddist, dist)
// ddangular = angular_doublebackward(ddcord, coord, dist)
// ddaev = concatenate(ddradial, ddangular)
// dparams(ddaev, ...) <-- output
template <typename DataT, typename IndexT = int>
struct AEVScalarParams {
DataT Rcr;
DataT Rca;
IndexT radial_sublength;
IndexT radial_length;
IndexT angular_sublength;
IndexT angular_length;
IndexT num_species;
AEVScalarParams() = default;
AEVScalarParams(const torch::IValue& aev_params_ivalue) {
c10::intrusive_ptr<c10::ivalue::Tuple> aev_params_tuple_ptr = aev_params_ivalue.toTuple();
auto aev_params_tuple = aev_params_tuple_ptr->elements();
Rcr = static_cast<DataT>(aev_params_tuple[0].toDouble());
Rca = static_cast<DataT>(aev_params_tuple[1].toDouble());
radial_sublength = static_cast<IndexT>(aev_params_tuple[2].toInt());
radial_length = static_cast<IndexT>(aev_params_tuple[3].toInt());
angular_sublength = static_cast<IndexT>(aev_params_tuple[4].toInt());
angular_length = static_cast<IndexT>(aev_params_tuple[5].toInt());
num_species = static_cast<IndexT>(aev_params_tuple[6].toInt());
}
operator torch::IValue() {
return torch::IValue(std::make_tuple(
(double)Rcr, (double)Rca, radial_sublength, radial_length, angular_sublength, angular_length, num_species));
}
};
// fetch from the following matrix // fetch from the following matrix
// [[ 0, 1, 2, 3, 4], // [[ 0, 1, 2, 3, 4],
...@@ -343,7 +211,11 @@ __global__ void cuAngularAEVs( ...@@ -343,7 +211,11 @@ __global__ void cuAngularAEVs(
PairDist<DataT>* d_centralAtom, PairDist<DataT>* d_centralAtom,
int* d_nPairsPerCenterAtom, int* d_nPairsPerCenterAtom,
int* d_centerAtomStartIdx, int* d_centerAtomStartIdx,
AEVScalarParams<DataT, IndexT> aev_params, float Rca,
int angular_length,
int angular_sublength,
int radial_length,
int num_species,
int maxnbrs_per_atom_aligned, int maxnbrs_per_atom_aligned,
int angular_length_aligned, int angular_length_aligned,
int ncentral_atoms) { int ncentral_atoms) {
...@@ -386,8 +258,6 @@ __global__ void cuAngularAEVs( ...@@ -386,8 +258,6 @@ __global__ void cuAngularAEVs(
IndexT nShfA = ShfA_t.size(0); IndexT nShfA = ShfA_t.size(0);
IndexT nShfZ = ShfZ_t.size(0); IndexT nShfZ = ShfZ_t.size(0);
DataT Rca = aev_params.Rca;
IndexT num_species = aev_params.num_species;
PairDist<DataT> d = d_centralAtom[cIdx]; PairDist<DataT> d = d_centralAtom[cIdx];
int start_idx = d_centerAtomStartIdx[cIdx]; int start_idx = d_centerAtomStartIdx[cIdx];
...@@ -397,7 +267,7 @@ __global__ void cuAngularAEVs( ...@@ -397,7 +267,7 @@ __global__ void cuAngularAEVs(
int i = d.i; int i = d.i;
int mol_idx = d.midx; int mol_idx = d.midx;
for (int iaev = laneIdx; iaev < aev_params.angular_length; iaev += threads_per_catom) { for (int iaev = laneIdx; iaev < angular_length; iaev += threads_per_catom) {
saev[iaev] = 0; saev[iaev] = 0;
} }
...@@ -449,7 +319,7 @@ __global__ void cuAngularAEVs( ...@@ -449,7 +319,7 @@ __global__ void cuAngularAEVs(
DataT Rijk = (Rij + Rik) / 2; DataT Rijk = (Rij + Rik) / 2;
DataT fc_ijk = fc_ij * fc_ik; DataT fc_ijk = fc_ij * fc_ik;
IndexT subaev_offset = aev_params.angular_sublength * csubaev_offsets(type_j, type_k, num_species); IndexT subaev_offset = angular_sublength * csubaev_offsets(type_j, type_k, num_species);
for (int itheta = tile.x; itheta < nShfZ; itheta += TILEX) { for (int itheta = tile.x; itheta < nShfZ; itheta += TILEX) {
DataT ShfZ = ShfZ_t[itheta]; DataT ShfZ = ShfZ_t[itheta];
...@@ -469,8 +339,8 @@ __global__ void cuAngularAEVs( ...@@ -469,8 +339,8 @@ __global__ void cuAngularAEVs(
} }
} }
for (int iaev = laneIdx; iaev < aev_params.angular_length; iaev += threads_per_catom) { for (int iaev = laneIdx; iaev < angular_length; iaev += threads_per_catom) {
aev_t[mol_idx][i][aev_params.radial_length + iaev] = saev[iaev]; aev_t[mol_idx][i][radial_length + iaev] = saev[iaev];
} }
} }
...@@ -496,7 +366,11 @@ __global__ void cuAngularAEVs_backward_or_doublebackward( ...@@ -496,7 +366,11 @@ __global__ void cuAngularAEVs_backward_or_doublebackward(
const PairDist<DataT>* d_centralAtom, const PairDist<DataT>* d_centralAtom,
int* d_nPairsPerCenterAtom, int* d_nPairsPerCenterAtom,
int* d_centerAtomStartIdx, int* d_centerAtomStartIdx,
AEVScalarParams<DataT, IndexT> aev_params, float Rca,
int angular_length,
int angular_sublength,
int radial_length,
int num_species,
int maxnbrs_per_atom_aligned, int maxnbrs_per_atom_aligned,
int angular_length_aligned, int angular_length_aligned,
int ncentral_atoms) { int ncentral_atoms) {
...@@ -548,8 +422,6 @@ __global__ void cuAngularAEVs_backward_or_doublebackward( ...@@ -548,8 +422,6 @@ __global__ void cuAngularAEVs_backward_or_doublebackward(
IndexT nShfA = ShfA_t.size(0); IndexT nShfA = ShfA_t.size(0);
IndexT nShfZ = ShfZ_t.size(0); IndexT nShfZ = ShfZ_t.size(0);
DataT Rca = aev_params.Rca;
IndexT num_species = aev_params.num_species;
PairDist<DataT> d = d_centralAtom[cIdx]; PairDist<DataT> d = d_centralAtom[cIdx];
int start_idx = d_centerAtomStartIdx[cIdx]; int start_idx = d_centerAtomStartIdx[cIdx];
...@@ -650,7 +522,7 @@ __global__ void cuAngularAEVs_backward_or_doublebackward( ...@@ -650,7 +522,7 @@ __global__ void cuAngularAEVs_backward_or_doublebackward(
DataT Rijk = (Rij + Rik) / 2; DataT Rijk = (Rij + Rik) / 2;
DataT fc_ijk = fc_ij * fc_ik; DataT fc_ijk = fc_ij * fc_ik;
IndexT subaev_offset = aev_params.angular_sublength * csubaev_offsets(type_j, type_k, num_species); IndexT subaev_offset = angular_sublength * csubaev_offsets(type_j, type_k, num_species);
for (int itheta = tile.x; itheta < nShfZ; itheta += TILEX) { for (int itheta = tile.x; itheta < nShfZ; itheta += TILEX) {
DataT ShfZ = ShfZ_t[itheta]; DataT ShfZ = ShfZ_t[itheta];
...@@ -701,11 +573,10 @@ __global__ void cuAngularAEVs_backward_or_doublebackward( ...@@ -701,11 +573,10 @@ __global__ void cuAngularAEVs_backward_or_doublebackward(
grad_vik_y *= (grad_force[mol_idx][atomk_idx][1] - grad_force[mol_idx][i][1]); grad_vik_y *= (grad_force[mol_idx][atomk_idx][1] - grad_force[mol_idx][i][1]);
grad_vik_z *= (grad_force[mol_idx][atomk_idx][2] - grad_force[mol_idx][i][2]); grad_vik_z *= (grad_force[mol_idx][atomk_idx][2] - grad_force[mol_idx][i][2]);
atomicAdd( atomicAdd(
&grad_grad_aev[mol_idx][i][aev_params.radial_length + subaev_offset + ishfr * nShfZ + itheta], &grad_grad_aev[mol_idx][i][radial_length + subaev_offset + ishfr * nShfZ + itheta],
grad_vij_x + grad_vij_y + grad_vij_z + grad_vik_x + grad_vik_y + grad_vik_z); grad_vij_x + grad_vij_y + grad_vij_z + grad_vik_x + grad_vik_y + grad_vik_z);
} else { } else {
DataT grad_output_item = DataT grad_output_item = grad_output[mol_idx][i][radial_length + subaev_offset + ishfr * nShfZ + itheta];
grad_output[mol_idx][i][aev_params.radial_length + subaev_offset + ishfr * nShfZ + itheta];
grad_vij_x *= grad_output_item; grad_vij_x *= grad_output_item;
grad_vij_y *= grad_output_item; grad_vij_y *= grad_output_item;
grad_vij_z *= grad_output_item; grad_vij_z *= grad_output_item;
...@@ -765,7 +636,9 @@ __global__ void cuRadialAEVs( ...@@ -765,7 +636,9 @@ __global__ void cuRadialAEVs(
torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> EtaR_t, torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> EtaR_t,
torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> aev_t, torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> aev_t,
PairDist<DataT>* d_Rij, PairDist<DataT>* d_Rij,
AEVScalarParams<DataT, int> aev_params, float Rcr,
int radial_length,
int radial_sublength,
int nRadialRij) { int nRadialRij) {
int gidx = blockIdx.x * blockDim.x + threadIdx.x; int gidx = blockIdx.x * blockDim.x + threadIdx.x;
int idx = gidx / THREADS_PER_RIJ; int idx = gidx / THREADS_PER_RIJ;
...@@ -786,14 +659,14 @@ __global__ void cuRadialAEVs( ...@@ -786,14 +659,14 @@ __global__ void cuRadialAEVs(
SpeciesT type_j = species_t[mol_idx][j]; SpeciesT type_j = species_t[mol_idx][j];
DataT fc = 0.5 * cos(PI * Rij / aev_params.Rcr) + 0.5; DataT fc = 0.5 * cos(PI * Rij / Rcr) + 0.5;
for (int ishfr = laneIdx; ishfr < nShfR; ishfr += THREADS_PER_RIJ) { for (int ishfr = laneIdx; ishfr < nShfR; ishfr += THREADS_PER_RIJ) {
DataT ShfR = ShfR_t[ishfr]; DataT ShfR = ShfR_t[ishfr];
DataT GmR = 0.25 * exp(-EtaR * (Rij - ShfR) * (Rij - ShfR)) * fc; DataT GmR = 0.25 * exp(-EtaR * (Rij - ShfR) * (Rij - ShfR)) * fc;
atomicAdd(&aev_t[mol_idx][i][type_j * aev_params.radial_sublength + ishfr], GmR); atomicAdd(&aev_t[mol_idx][i][type_j * radial_sublength + ishfr], GmR);
} }
} }
...@@ -808,7 +681,9 @@ __global__ void cuRadialAEVs_backward_or_doublebackward( ...@@ -808,7 +681,9 @@ __global__ void cuRadialAEVs_backward_or_doublebackward(
torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits>
grad_dist, // ddist for backward, dddist for double backward grad_dist, // ddist for backward, dddist for double backward
const PairDist<DataT>* d_Rij, const PairDist<DataT>* d_Rij,
AEVScalarParams<DataT, int> aev_params, float Rcr,
int radial_length,
int radial_sublength,
int nRadialRij) { int nRadialRij) {
int gidx = blockIdx.x * blockDim.x + threadIdx.x; int gidx = blockIdx.x * blockDim.x + threadIdx.x;
int idx = gidx / THREADS_PER_RIJ; int idx = gidx / THREADS_PER_RIJ;
...@@ -829,8 +704,8 @@ __global__ void cuRadialAEVs_backward_or_doublebackward( ...@@ -829,8 +704,8 @@ __global__ void cuRadialAEVs_backward_or_doublebackward(
SpeciesT type_j = species_t[mol_idx][j]; SpeciesT type_j = species_t[mol_idx][j];
DataT fc = 0.5 * cos(PI * Rij / aev_params.Rcr) + 0.5; DataT fc = 0.5 * cos(PI * Rij / Rcr) + 0.5;
DataT fc_grad = -0.5 * (PI / aev_params.Rcr) * sin(PI * Rij / aev_params.Rcr); DataT fc_grad = -0.5 * (PI / Rcr) * sin(PI * Rij / Rcr);
DataT upstream_grad; DataT upstream_grad;
if (is_double_backward) { if (is_double_backward) {
...@@ -845,9 +720,9 @@ __global__ void cuRadialAEVs_backward_or_doublebackward( ...@@ -845,9 +720,9 @@ __global__ void cuRadialAEVs_backward_or_doublebackward(
DataT jacobian = GmR_grad * fc + GmR * fc_grad; DataT jacobian = GmR_grad * fc + GmR * fc_grad;
if (is_double_backward) { if (is_double_backward) {
atomicAdd(&grad_aev[mol_idx][i][type_j * aev_params.radial_sublength + ishfr], upstream_grad * jacobian); atomicAdd(&grad_aev[mol_idx][i][type_j * radial_sublength + ishfr], upstream_grad * jacobian);
} else { } else {
upstream_grad = grad_aev[mol_idx][i][type_j * aev_params.radial_sublength + ishfr]; upstream_grad = grad_aev[mol_idx][i][type_j * radial_sublength + ishfr];
atomicAdd(&grad_dist[idx], upstream_grad * jacobian); atomicAdd(&grad_dist[idx], upstream_grad * jacobian);
} }
} }
...@@ -952,68 +827,26 @@ DataT cubMax(const DataT* d_in, int num_items, DataT* d_out, cudaStream_t stream ...@@ -952,68 +827,26 @@ DataT cubMax(const DataT* d_in, int num_items, DataT* d_out, cudaStream_t stream
return maxVal; return maxVal;
} }
struct Result {
Tensor aev_t;
AEVScalarParams<float> aev_params;
Tensor tensor_Rij;
Tensor tensor_radialRij;
Tensor tensor_angularRij;
int total_natom_pairs;
int nRadialRij;
int nAngularRij;
Tensor tensor_centralAtom;
Tensor tensor_numPairsPerCenterAtom;
Tensor tensor_centerAtomStartIdx;
int maxnbrs_per_atom_aligned;
int angular_length_aligned;
int ncenter_atoms;
};
// NOTE: assumes size of EtaA_t = Zeta_t = EtaR_t = 1 // NOTE: assumes size of EtaA_t = Zeta_t = EtaR_t = 1
template <typename ScalarRealT = float> Result cuaev_forward(const Tensor& coordinates_t, const Tensor& species_t, const AEVScalarParams& aev_params) {
Result cuaev_forward(
const Tensor& coordinates_t,
const Tensor& species_t,
double Rcr_,
double Rca_,
const Tensor& EtaR_t,
const Tensor& ShfR_t,
const Tensor& EtaA_t,
const Tensor& Zeta_t,
const Tensor& ShfA_t,
const Tensor& ShfZ_t,
int64_t num_species_) {
TORCH_CHECK( TORCH_CHECK(
(species_t.dtype() == torch::kInt32) && (coordinates_t.dtype() == torch::kFloat32), "Unsupported input type"); (species_t.dtype() == torch::kInt32) && (coordinates_t.dtype() == torch::kFloat32), "Unsupported input type");
TORCH_CHECK( TORCH_CHECK(
EtaR_t.size(0) == 1 || EtaA_t.size(0) == 1 || Zeta_t.size(0) == 1, aev_params.EtaR_t.size(0) == 1 || aev_params.EtaA_t.size(0) == 1 || aev_params.Zeta_t.size(0) == 1,
"cuda extension is currently not supported for the specified " "cuda extension is currently not supported for the specified "
"configuration"); "configuration");
ScalarRealT Rcr = Rcr_; float Rcr = aev_params.Rcr;
ScalarRealT Rca = Rca_; float Rca = aev_params.Rca;
int num_species = num_species_;
const int n_molecules = species_t.size(0); const int n_molecules = species_t.size(0);
const int max_natoms_per_mol = species_t.size(1); const int max_natoms_per_mol = species_t.size(1);
AEVScalarParams<float> aev_params;
aev_params.Rca = Rca;
aev_params.Rcr = Rcr;
aev_params.num_species = num_species;
aev_params.radial_sublength = EtaR_t.size(0) * ShfR_t.size(0);
aev_params.radial_length = aev_params.radial_sublength * num_species;
aev_params.angular_sublength = EtaA_t.size(0) * Zeta_t.size(0) * ShfA_t.size(0) * ShfZ_t.size(0);
aev_params.angular_length = aev_params.angular_sublength * (num_species * (num_species + 1) / 2);
int aev_length = aev_params.radial_length + aev_params.angular_length; int aev_length = aev_params.radial_length + aev_params.angular_length;
auto aev_t = torch::zeros({n_molecules, max_natoms_per_mol, aev_length}, coordinates_t.options()); auto aev_t = torch::zeros({n_molecules, max_natoms_per_mol, aev_length}, coordinates_t.options());
if (species_t.numel() == 0) { if (species_t.numel() == 0) {
return {aev_t, aev_params, Tensor(), Tensor(), Tensor(), 0, 0, 0}; return {
aev_t, Tensor(), Tensor(), Tensor(), 0, 0, 0, Tensor(), Tensor(), Tensor(), 0, 0, 0, coordinates_t, species_t};
} }
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream();
...@@ -1076,11 +909,13 @@ Result cuaev_forward( ...@@ -1076,11 +909,13 @@ Result cuaev_forward(
int nblocks = (nRadialRij * 8 + block_size - 1) / block_size; int nblocks = (nRadialRij * 8 + block_size - 1) / block_size;
cuRadialAEVs<int, float, 8><<<nblocks, block_size, 0, stream>>>( cuRadialAEVs<int, float, 8><<<nblocks, block_size, 0, stream>>>(
species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(), species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
ShfR_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(), aev_params.ShfR_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
EtaR_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(), aev_params.EtaR_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
aev_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(), aev_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
d_radialRij, d_radialRij,
aev_params, aev_params.Rcr,
aev_params.radial_length,
aev_params.radial_sublength,
nRadialRij); nRadialRij);
// reuse buffer allocated for all Rij // reuse buffer allocated for all Rij
...@@ -1111,6 +946,7 @@ Result cuaev_forward( ...@@ -1111,6 +946,7 @@ Result cuaev_forward(
int* d_centerAtomStartIdx = (int*)tensor_centerAtomStartIdx.data_ptr(); int* d_centerAtomStartIdx = (int*)tensor_centerAtomStartIdx.data_ptr();
cubScan(d_numPairsPerCenterAtom, d_centerAtomStartIdx, ncenter_atoms, stream); cubScan(d_numPairsPerCenterAtom, d_centerAtomStartIdx, ncenter_atoms, stream);
{ {
const int nthreads_per_catom = 32; const int nthreads_per_catom = 32;
const int nblocks_angAEV = (ncenter_atoms * nthreads_per_catom + block_size - 1) / block_size; const int nblocks_angAEV = (ncenter_atoms * nthreads_per_catom + block_size - 1) / block_size;
...@@ -1132,22 +968,25 @@ Result cuaev_forward( ...@@ -1132,22 +968,25 @@ Result cuaev_forward(
cuAngularAEVs<<<nblocks_angAEV, block_size, smem_size_aligned, stream>>>( cuAngularAEVs<<<nblocks_angAEV, block_size, smem_size_aligned, stream>>>(
species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(), species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(), coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
ShfA_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(), aev_params.ShfA_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
ShfZ_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(), aev_params.ShfZ_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
EtaA_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(), aev_params.EtaA_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
Zeta_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(), aev_params.Zeta_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
aev_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(), aev_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
d_angularRij, d_angularRij,
d_centralAtom, d_centralAtom,
d_numPairsPerCenterAtom, d_numPairsPerCenterAtom,
d_centerAtomStartIdx, d_centerAtomStartIdx,
aev_params, aev_params.Rca,
aev_params.angular_length,
aev_params.angular_sublength,
aev_params.radial_length,
aev_params.num_species,
maxnbrs_per_atom_aligned, maxnbrs_per_atom_aligned,
angular_length_aligned, angular_length_aligned,
ncenter_atoms); ncenter_atoms);
return {aev_t, return {aev_t,
aev_params,
tensor_Rij, tensor_Rij,
tensor_radialRij, tensor_radialRij,
tensor_angularRij, tensor_angularRij,
...@@ -1159,71 +998,54 @@ Result cuaev_forward( ...@@ -1159,71 +998,54 @@ Result cuaev_forward(
tensor_centerAtomStartIdx, tensor_centerAtomStartIdx,
maxnbrs_per_atom_aligned, maxnbrs_per_atom_aligned,
angular_length_aligned, angular_length_aligned,
ncenter_atoms}; ncenter_atoms,
coordinates_t,
species_t};
} }
} }
Tensor cuaev_backward( Tensor cuaev_backward(const Tensor& grad_output, const AEVScalarParams& aev_params, const Result& result) {
const Tensor& grad_output,
const Tensor& coordinates_t,
const Tensor& species_t,
const AEVScalarParams<float>& aev_params,
const Tensor& EtaR_t,
const Tensor& ShfR_t,
const Tensor& EtaA_t,
const Tensor& Zeta_t,
const Tensor& ShfA_t,
const Tensor& ShfZ_t,
const Tensor& tensor_Rij,
int total_natom_pairs,
const Tensor& tensor_radialRij,
int nRadialRij,
const Tensor& tensor_angularRij,
int nAngularRij,
const Tensor& tensor_centralAtom,
const Tensor& tensor_numPairsPerCenterAtom,
const Tensor& tensor_centerAtomStartIdx,
int maxnbrs_per_atom_aligned,
int angular_length_aligned,
int ncenter_atoms) {
using namespace torch::indexing; using namespace torch::indexing;
Tensor coordinates_t = result.coordinates_t;
Tensor species_t = result.species_t;
const int n_molecules = coordinates_t.size(0); const int n_molecules = coordinates_t.size(0);
const int max_natoms_per_mol = coordinates_t.size(1); const int max_natoms_per_mol = coordinates_t.size(1);
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto grad_coord = torch::zeros(coordinates_t.sizes(), coordinates_t.options().requires_grad(false)); // [2, 5, 3] auto grad_coord = torch::zeros(coordinates_t.sizes(), coordinates_t.options().requires_grad(false)); // [2, 5, 3]
auto grad_output_radial = grad_output.index({Ellipsis, Slice(None, aev_params.radial_length)}); // [2, 5, 64]
auto grad_output_angular = grad_output.index({Ellipsis, Slice(aev_params.radial_length, None)}); // [2, 5, 320]
PairDist<float>* d_Rij = (PairDist<float>*)tensor_Rij.data_ptr(); PairDist<float>* d_Rij = (PairDist<float>*)result.tensor_Rij.data_ptr();
PairDist<float>* d_radialRij = (PairDist<float>*)tensor_radialRij.data_ptr(); PairDist<float>* d_radialRij = (PairDist<float>*)result.tensor_radialRij.data_ptr();
PairDist<float>* d_angularRij = (PairDist<float>*)tensor_angularRij.data_ptr(); PairDist<float>* d_angularRij = (PairDist<float>*)result.tensor_angularRij.data_ptr();
PairDist<float>* d_centralAtom = (PairDist<float>*)tensor_centralAtom.data_ptr(); PairDist<float>* d_centralAtom = (PairDist<float>*)result.tensor_centralAtom.data_ptr();
int* d_numPairsPerCenterAtom = (int*)tensor_numPairsPerCenterAtom.data_ptr(); int* d_numPairsPerCenterAtom = (int*)result.tensor_numPairsPerCenterAtom.data_ptr();
int* d_centerAtomStartIdx = (int*)tensor_centerAtomStartIdx.data_ptr(); int* d_centerAtomStartIdx = (int*)result.tensor_centerAtomStartIdx.data_ptr();
Tensor grad_radial_dist = torch::zeros(nRadialRij, coordinates_t.options().requires_grad(false)); Tensor grad_radial_dist = torch::zeros(result.nRadialRij, coordinates_t.options().requires_grad(false));
int block_size = 64; int block_size = 64;
int nblocks = (nRadialRij * 8 + block_size - 1) / block_size; int nblocks = (result.nRadialRij * 8 + block_size - 1) / block_size;
cuRadialAEVs_backward_or_doublebackward<false, int, float, 8><<<nblocks, block_size, 0, stream>>>( cuRadialAEVs_backward_or_doublebackward<false, int, float, 8><<<nblocks, block_size, 0, stream>>>(
species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(), species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
ShfR_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(), aev_params.ShfR_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
EtaR_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(), aev_params.EtaR_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
grad_output.packed_accessor32<float, 3, torch::RestrictPtrTraits>(), grad_output.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
grad_radial_dist.packed_accessor32<float, 1, torch::RestrictPtrTraits>(), grad_radial_dist.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
d_radialRij, d_radialRij,
aev_params, aev_params.Rcr,
nRadialRij); aev_params.radial_length,
aev_params.radial_sublength,
result.nRadialRij);
// For best result, block_size should match average molecule size (no padding) to avoid atomicAdd // For best result, block_size should match average molecule size (no padding) to avoid atomicAdd
nblocks = (nRadialRij + block_size - 1) / block_size; nblocks = (result.nRadialRij + block_size - 1) / block_size;
pairwiseDistance_backward_or_doublebackward<false><<<nblocks, block_size, 0, stream>>>( pairwiseDistance_backward_or_doublebackward<false><<<nblocks, block_size, 0, stream>>>(
coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(), coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
grad_radial_dist.packed_accessor32<float, 1, torch::RestrictPtrTraits>(), grad_radial_dist.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
grad_coord.packed_accessor32<float, 3, torch::RestrictPtrTraits>(), grad_coord.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
d_radialRij, d_radialRij,
nRadialRij); result.nRadialRij);
auto smem_size = [&aev_params](int max_nbrs, int ncatom_per_tpb) { auto smem_size = [&aev_params](int max_nbrs, int ncatom_per_tpb) {
int sxyz = sizeof(float) * max_nbrs * 3; int sxyz = sizeof(float) * max_nbrs * 3;
...@@ -1238,55 +1060,40 @@ Tensor cuaev_backward( ...@@ -1238,55 +1060,40 @@ Tensor cuaev_backward(
block_size = 32; block_size = 32;
const int nthreads_per_catom = 32; const int nthreads_per_catom = 32;
const int nblocks_angAEV = (ncenter_atoms * nthreads_per_catom + block_size - 1) / block_size; const int nblocks_angAEV = (result.ncenter_atoms * nthreads_per_catom + block_size - 1) / block_size;
int smem_size_aligned = smem_size(maxnbrs_per_atom_aligned, block_size / nthreads_per_catom); int smem_size_aligned = smem_size(result.maxnbrs_per_atom_aligned, block_size / nthreads_per_catom);
Tensor grad_angular_coord = torch::zeros({nAngularRij, 3}, coordinates_t.options().requires_grad(false)); Tensor grad_angular_coord = torch::zeros({result.nAngularRij, 3}, coordinates_t.options().requires_grad(false));
cuAngularAEVs_backward_or_doublebackward<false><<<nblocks_angAEV, block_size, smem_size_aligned, stream>>>( cuAngularAEVs_backward_or_doublebackward<false><<<nblocks_angAEV, block_size, smem_size_aligned, stream>>>(
species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(), species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(), coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
ShfA_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(), aev_params.ShfA_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
ShfZ_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(), aev_params.ShfZ_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
EtaA_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(), aev_params.EtaA_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
Zeta_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(), aev_params.Zeta_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
grad_output.packed_accessor32<float, 3, torch::RestrictPtrTraits>(), grad_output.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
grad_coord.packed_accessor32<float, 3, torch::RestrictPtrTraits>(), grad_coord.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
d_angularRij, d_angularRij,
d_centralAtom, d_centralAtom,
d_numPairsPerCenterAtom, d_numPairsPerCenterAtom,
d_centerAtomStartIdx, d_centerAtomStartIdx,
aev_params, aev_params.Rca,
maxnbrs_per_atom_aligned, aev_params.angular_length,
angular_length_aligned, aev_params.angular_sublength,
ncenter_atoms); aev_params.radial_length,
aev_params.num_species,
result.maxnbrs_per_atom_aligned,
result.angular_length_aligned,
result.ncenter_atoms);
return grad_coord; return grad_coord;
} }
Tensor cuaev_double_backward( Tensor cuaev_double_backward(const Tensor& grad_force, const AEVScalarParams& aev_params, const Result& result) {
const Tensor& grad_force,
const Tensor& coordinates_t,
const Tensor& species_t,
const AEVScalarParams<float>& aev_params,
const Tensor& EtaR_t,
const Tensor& ShfR_t,
const Tensor& EtaA_t,
const Tensor& Zeta_t,
const Tensor& ShfA_t,
const Tensor& ShfZ_t,
const Tensor& tensor_Rij,
int total_natom_pairs,
const Tensor& tensor_radialRij,
int nRadialRij,
const Tensor& tensor_angularRij,
int nAngularRij,
const Tensor& tensor_centralAtom,
const Tensor& tensor_numPairsPerCenterAtom,
const Tensor& tensor_centerAtomStartIdx,
int maxnbrs_per_atom_aligned,
int angular_length_aligned,
int ncenter_atoms) {
using namespace torch::indexing; using namespace torch::indexing;
Tensor coordinates_t = result.coordinates_t;
Tensor species_t = result.species_t;
const int n_molecules = coordinates_t.size(0); const int n_molecules = coordinates_t.size(0);
const int max_natoms_per_mol = coordinates_t.size(1); const int max_natoms_per_mol = coordinates_t.size(1);
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream();
...@@ -1297,34 +1104,36 @@ Tensor cuaev_double_backward( ...@@ -1297,34 +1104,36 @@ Tensor cuaev_double_backward(
{coordinates_t.size(0), coordinates_t.size(1), aev_length}, {coordinates_t.size(0), coordinates_t.size(1), aev_length},
coordinates_t.options().requires_grad(false)); // [2, 5, 384] coordinates_t.options().requires_grad(false)); // [2, 5, 384]
PairDist<float>* d_Rij = (PairDist<float>*)tensor_Rij.data_ptr(); PairDist<float>* d_Rij = (PairDist<float>*)result.tensor_Rij.data_ptr();
PairDist<float>* d_radialRij = (PairDist<float>*)tensor_radialRij.data_ptr(); PairDist<float>* d_radialRij = (PairDist<float>*)result.tensor_radialRij.data_ptr();
PairDist<float>* d_angularRij = (PairDist<float>*)tensor_angularRij.data_ptr(); PairDist<float>* d_angularRij = (PairDist<float>*)result.tensor_angularRij.data_ptr();
PairDist<float>* d_centralAtom = (PairDist<float>*)tensor_centralAtom.data_ptr(); PairDist<float>* d_centralAtom = (PairDist<float>*)result.tensor_centralAtom.data_ptr();
int* d_numPairsPerCenterAtom = (int*)tensor_numPairsPerCenterAtom.data_ptr(); int* d_numPairsPerCenterAtom = (int*)result.tensor_numPairsPerCenterAtom.data_ptr();
int* d_centerAtomStartIdx = (int*)tensor_centerAtomStartIdx.data_ptr(); int* d_centerAtomStartIdx = (int*)result.tensor_centerAtomStartIdx.data_ptr();
auto grad_force_coord_Rij = torch::zeros({nRadialRij}, coordinates_t.options().requires_grad(false)); auto grad_force_coord_Rij = torch::zeros({result.nRadialRij}, coordinates_t.options().requires_grad(false));
int block_size = 64; int block_size = 64;
int nblocks = (nRadialRij + block_size - 1) / block_size; int nblocks = (result.nRadialRij + block_size - 1) / block_size;
pairwiseDistance_backward_or_doublebackward<true><<<nblocks, block_size, 0, stream>>>( pairwiseDistance_backward_or_doublebackward<true><<<nblocks, block_size, 0, stream>>>(
coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(), coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
grad_force_coord_Rij.packed_accessor32<float, 1, torch::RestrictPtrTraits>(), grad_force_coord_Rij.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
grad_force.packed_accessor32<float, 3, torch::RestrictPtrTraits>(), grad_force.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
d_radialRij, d_radialRij,
nRadialRij); result.nRadialRij);
nblocks = (nRadialRij * 8 + block_size - 1) / block_size; nblocks = (result.nRadialRij * 8 + block_size - 1) / block_size;
cuRadialAEVs_backward_or_doublebackward<true, int, float, 8><<<nblocks, block_size, 0, stream>>>( cuRadialAEVs_backward_or_doublebackward<true, int, float, 8><<<nblocks, block_size, 0, stream>>>(
species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(), species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
ShfR_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(), aev_params.ShfR_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
EtaR_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(), aev_params.EtaR_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
grad_grad_aev.packed_accessor32<float, 3, torch::RestrictPtrTraits>(), grad_grad_aev.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
grad_force_coord_Rij.packed_accessor32<float, 1, torch::RestrictPtrTraits>(), grad_force_coord_Rij.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
d_radialRij, d_radialRij,
aev_params, aev_params.Rcr,
nRadialRij); aev_params.radial_length,
aev_params.radial_sublength,
result.nRadialRij);
auto smem_size = [&aev_params](int max_nbrs, int ncatom_per_tpb) { auto smem_size = [&aev_params](int max_nbrs, int ncatom_per_tpb) {
int sxyz = sizeof(float) * max_nbrs * 3; int sxyz = sizeof(float) * max_nbrs * 3;
...@@ -1339,202 +1148,30 @@ Tensor cuaev_double_backward( ...@@ -1339,202 +1148,30 @@ Tensor cuaev_double_backward(
block_size = 32; block_size = 32;
const int nthreads_per_catom = 32; const int nthreads_per_catom = 32;
const int nblocks_angAEV = (ncenter_atoms * nthreads_per_catom + block_size - 1) / block_size; const int nblocks_angAEV = (result.ncenter_atoms * nthreads_per_catom + block_size - 1) / block_size;
int smem_size_aligned = smem_size(maxnbrs_per_atom_aligned, block_size / nthreads_per_catom); int smem_size_aligned = smem_size(result.maxnbrs_per_atom_aligned, block_size / nthreads_per_catom);
cuAngularAEVs_backward_or_doublebackward<true><<<nblocks_angAEV, block_size, smem_size_aligned, stream>>>( cuAngularAEVs_backward_or_doublebackward<true><<<nblocks_angAEV, block_size, smem_size_aligned, stream>>>(
species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(), species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(), coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
ShfA_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(), aev_params.ShfA_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
ShfZ_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(), aev_params.ShfZ_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
EtaA_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(), aev_params.EtaA_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
Zeta_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(), aev_params.Zeta_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
grad_force.packed_accessor32<float, 3, torch::RestrictPtrTraits>(), grad_force.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
grad_grad_aev.packed_accessor32<float, 3, torch::RestrictPtrTraits>(), grad_grad_aev.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
d_angularRij, d_angularRij,
d_centralAtom, d_centralAtom,
d_numPairsPerCenterAtom, d_numPairsPerCenterAtom,
d_centerAtomStartIdx, d_centerAtomStartIdx,
aev_params, aev_params.Rca,
maxnbrs_per_atom_aligned, aev_params.angular_length,
angular_length_aligned, aev_params.angular_sublength,
ncenter_atoms); aev_params.radial_length,
aev_params.num_species,
result.maxnbrs_per_atom_aligned,
result.angular_length_aligned,
result.ncenter_atoms);
return grad_grad_aev; return grad_grad_aev;
} }
// First-order backward of the AEV computation wrapped in its own autograd
// Function, so that the backward itself is differentiable (required for
// force training, i.e. double backward).
//
// NOTE(review): `prectx` is the AutogradContext of the preceding
// CuaevAutograd::forward node; all saved tensors/scalars are re-read from it.
class CuaevDoubleAutograd : public torch::autograd::Function<CuaevDoubleAutograd> {
 public:
  // Computes grad_coord (force) = (dE/daev) * (daev/dcoord) from grad_e_aev.
  static Tensor forward(AutogradContext* ctx, Tensor grad_e_aev, AutogradContext* prectx) {
    // Unpack the 14 tensors saved by CuaevAutograd::forward, same order.
    auto saved = prectx->get_saved_variables();
    auto coordinates_t = saved[0], species_t = saved[1];
    auto tensor_Rij = saved[2], tensor_radialRij = saved[3], tensor_angularRij = saved[4];
    auto EtaR_t = saved[5], ShfR_t = saved[6], EtaA_t = saved[7], Zeta_t = saved[8], ShfA_t = saved[9],
         ShfZ_t = saved[10];
    auto tensor_centralAtom = saved[11], tensor_numPairsPerCenterAtom = saved[12],
         tensor_centerAtomStartIdx = saved[13];
    // Scalar metadata travels through saved_data (IValue), not tensors.
    AEVScalarParams<float> aev_params(prectx->saved_data["aev_params"]);
    c10::List<int64_t> int_list = prectx->saved_data["int_list"].toIntList();
    int total_natom_pairs = int_list[0], nRadialRij = int_list[1], nAngularRij = int_list[2];
    int maxnbrs_per_atom_aligned = int_list[3], angular_length_aligned = int_list[4];
    int ncenter_atoms = int_list[5];
    // Re-save on *this* node only when a double backward can actually run.
    if (grad_e_aev.requires_grad()) {
      ctx->save_for_backward({coordinates_t,
                              species_t,
                              tensor_Rij,
                              tensor_radialRij,
                              tensor_angularRij,
                              EtaR_t,
                              ShfR_t,
                              EtaA_t,
                              Zeta_t,
                              ShfA_t,
                              ShfZ_t,
                              tensor_centralAtom,
                              tensor_numPairsPerCenterAtom,
                              tensor_centerAtomStartIdx});
      ctx->saved_data["aev_params"] = aev_params;
      ctx->saved_data["int_list"] = c10::List<int64_t>{
          total_natom_pairs, nRadialRij, nAngularRij, maxnbrs_per_atom_aligned, angular_length_aligned, ncenter_atoms};
    }
    Tensor grad_coord = cuaev_backward(
        grad_e_aev,
        coordinates_t,
        species_t,
        aev_params,
        EtaR_t,
        ShfR_t,
        EtaA_t,
        Zeta_t,
        ShfA_t,
        ShfZ_t,
        tensor_Rij,
        total_natom_pairs,
        tensor_radialRij,
        nRadialRij,
        tensor_angularRij,
        nAngularRij,
        tensor_centralAtom,
        tensor_numPairsPerCenterAtom,
        tensor_centerAtomStartIdx,
        maxnbrs_per_atom_aligned,
        angular_length_aligned,
        ncenter_atoms);
    return grad_coord;
  }

  // Double backward: grad_force -> grad_grad_aev (dloss/dg in the diagram).
  static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) {
    Tensor grad_force = grad_outputs[0];
    // Same 14-tensor unpack as forward(), but from this node's own context.
    auto saved = ctx->get_saved_variables();
    auto coordinates_t = saved[0], species_t = saved[1];
    auto tensor_Rij = saved[2], tensor_radialRij = saved[3], tensor_angularRij = saved[4];
    auto EtaR_t = saved[5], ShfR_t = saved[6], EtaA_t = saved[7], Zeta_t = saved[8], ShfA_t = saved[9],
         ShfZ_t = saved[10];
    auto tensor_centralAtom = saved[11], tensor_numPairsPerCenterAtom = saved[12],
         tensor_centerAtomStartIdx = saved[13];
    AEVScalarParams<float> aev_params(ctx->saved_data["aev_params"]);
    c10::List<int64_t> int_list = ctx->saved_data["int_list"].toIntList();
    int total_natom_pairs = int_list[0], nRadialRij = int_list[1], nAngularRij = int_list[2];
    int maxnbrs_per_atom_aligned = int_list[3], angular_length_aligned = int_list[4];
    int ncenter_atoms = int_list[5];
    Tensor grad_grad_aev = cuaev_double_backward(
        grad_force,
        coordinates_t,
        species_t,
        aev_params,
        EtaR_t,
        ShfR_t,
        EtaA_t,
        Zeta_t,
        ShfA_t,
        ShfZ_t,
        tensor_Rij,
        total_natom_pairs,
        tensor_radialRij,
        nRadialRij,
        tensor_angularRij,
        nAngularRij,
        tensor_centralAtom,
        tensor_numPairsPerCenterAtom,
        tensor_centerAtomStartIdx,
        maxnbrs_per_atom_aligned,
        angular_length_aligned,
        ncenter_atoms);
    // Two gradient slots: one per forward() input (grad_e_aev, prectx).
    return {grad_grad_aev, torch::Tensor()};
  }
};
// Shared parameter list for the cuaev entry points (kept as a macro so the
// op schema and the autograd wrapper stay in sync).
#define AEV_INPUT                                                                                                   \
  const Tensor &coordinates_t, const Tensor &species_t, double Rcr_, double Rca_, const Tensor &EtaR_t,             \
      const Tensor &ShfR_t, const Tensor &EtaA_t, const Tensor &Zeta_t, const Tensor &ShfA_t, const Tensor &ShfZ_t, \
      int64_t num_species_

// Inference-only entry point: runs the CUDA forward and returns just the AEV
// tensor, discarding the intermediates kept for backward.
Tensor cuaev_cuda(AEV_INPUT) {
  Result res = cuaev_forward<float>(
      coordinates_t, species_t, Rcr_, Rca_, EtaR_t, ShfR_t, EtaA_t, Zeta_t, ShfA_t, ShfZ_t, num_species_);
  return res.aev_t;
}
// Autograd wrapper for the AEV forward. Saves the forward intermediates so
// CuaevDoubleAutograd can compute forces (and, transitively, double backward).
class CuaevAutograd : public torch::autograd::Function<CuaevAutograd> {
 public:
  static Tensor forward(AutogradContext* ctx, AEV_INPUT) {
    // Run the kernels without autograd recording; gradients come from this
    // Function's custom backward instead.
    at::AutoNonVariableTypeMode g;
    Result res = cuaev_forward<float>(
        coordinates_t, species_t, Rcr_, Rca_, EtaR_t, ShfR_t, EtaA_t, Zeta_t, ShfA_t, ShfZ_t, num_species_);
    // Save intermediates only when a backward pass will actually need them.
    if (coordinates_t.requires_grad()) {
      ctx->save_for_backward({coordinates_t,
                              species_t,
                              res.tensor_Rij,
                              res.tensor_radialRij,
                              res.tensor_angularRij,
                              EtaR_t,
                              ShfR_t,
                              EtaA_t,
                              Zeta_t,
                              ShfA_t,
                              ShfZ_t,
                              res.tensor_centralAtom,
                              res.tensor_numPairsPerCenterAtom,
                              res.tensor_centerAtomStartIdx});
      // Non-tensor state goes through saved_data (IValue).
      ctx->saved_data["aev_params"] = res.aev_params;
      ctx->saved_data["int_list"] = c10::List<int64_t>{res.total_natom_pairs,
                                                       res.nRadialRij,
                                                       res.nAngularRij,
                                                       res.maxnbrs_per_atom_aligned,
                                                       res.angular_length_aligned,
                                                       res.ncenter_atoms};
    }
    return res.aev_t;
  }

  static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) {
    // Delegate to a second autograd Function so the backward is itself
    // differentiable; pass our own ctx so it can reuse the saved state.
    Tensor grad_coord = CuaevDoubleAutograd::apply(grad_outputs[0], ctx);
    // One gradient slot per AEV_INPUT argument; only coordinates get a grad.
    return {
        grad_coord, Tensor(), Tensor(), Tensor(), Tensor(), Tensor(), Tensor(), Tensor(), Tensor(), Tensor(), Tensor()};
  }
};
// Autograd-enabled entry point: dispatches through CuaevAutograd.
Tensor cuaev_autograd(AEV_INPUT) {
  return CuaevAutograd::apply(
      coordinates_t, species_t, Rcr_, Rca_, EtaR_t, ShfR_t, EtaA_t, Zeta_t, ShfA_t, ShfZ_t, num_species_);
}
// Declare the op schema for the cuaev namespace.
TORCH_LIBRARY(cuaev, m) {
  m.def("cuComputeAEV", cuaev_cuda);
}

// CUDA dispatch key: plain forward, no autograd graph built.
TORCH_LIBRARY_IMPL(cuaev, CUDA, m) {
  m.impl("cuComputeAEV", cuaev_cuda);
}

// Autograd dispatch key: wraps the computation in CuaevAutograd.
TORCH_LIBRARY_IMPL(cuaev, Autograd, m) {
  m.impl("cuComputeAEV", cuaev_autograd);
}

// The extension exposes nothing via pybind; all access goes through the
// torch library registrations above.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
#pragma once
#include <c10/cuda/CUDACachingAllocator.h>
#include <torch/extension.h>
using torch::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::tensor_list;
// [Computation graph for forward, backward, and double backward]
//
// backward
// force = (dE / daev) * (daev / dcoord) = g * (daev / dcoord)
//
// double backward (to do force training, the term needed is)
// dloss / dg = (dloss / dforce) * (dforce / dg) = (dloss / dforce) * (daev / dcoord)
//
//
// [Forward]
// out ^
// | ^
// ... ^
// | ^
// e n e r g y ^
// | \ ^
// aev \ ^
// / | \ ^
// radial angular params ^
// / / | ^
// dist---^ / ^
// \ / ^
// coord ^
//
// Functional relationship:
// coord <-- input
// dist(coord)
// radial(dist)
// angular(dist, coord)
// aev = concatenate(radial, angular)
// energy(aev, params)
// out(energy, ....) <-- output
//
//
// [Backward]
// dout v
// | v
// ... v
// | v
// aev params denergy aev params v
// \ | / \ | / v
// d a e v dparams v
// / \____ v
// dist dradial \ v
// \ / \ v
// ddist dist coord dangular dist coord v
// \ / / \ | / v
// \_/____/ \___|___/ v
// | __________________/ v
// | / v
// dcoord v
// | v
// ... v
// | v
// out2 v
//
// Functional relationship:
// dout <-- input
// denergy(dout)
// dparams(denergy, aev, params) <-- output
// daev(denergy, aev, params)
// dradial = slice(daev)
// dangular = slice(daev)
// ddist = radial_backward(dradial, dist) + angular_backward_dist(dangular, ...)
// = radial_backward(dradial, dist) + 0 (all contributions route to dcoord)
// = radial_backward(dradial, dist)
// dcoord = dist_backward(ddist, coord, dist) + angular_backward_coord(dangular, coord, dist)
// out2(dcoord, ...) <-- output
//
//
// [Double backward w.r.t params (i.e. force training)]
// Note: only a very limited subset of double backward is implemented
// currently it can only do force training, there is no hessian support
// not implemented terms are marked by $s
// $$$ [dparams] $$$$ ^
// \_ | __/ ^
// [ddaev] ^
// / \_____ ^
// $$$$ [ddradial] \ ^
// \ / \ ^
// [dddist] $$$$ $$$$ [ddangular] $$$$ $$$$ ^
// \ / / \ | / ^
// \_/____/ \_____|___/ ^
// | _____________________/ ^
// | / ^
// [ddcoord] ^
// | ^
// ... ^
// | ^
// [dout2] ^
//
// Functional relationship:
// dout2 <-- input
// ddcoord(dout2, ...)
// dddist = dist_doublebackward(ddcoord, coord, dist)
// ddradial = radial_doublebackward(dddist, dist)
// ddangular = angular_doublebackward(ddcoord, coord, dist)
// ddaev = concatenate(ddradial, ddangular)
// dparams(ddaev, ...) <-- output
// All constants of the AEV definition: cutoff radii, the derived feature
// segment lengths, and the parameter tensors themselves. A single instance
// is shared by forward / backward / double backward (see CuaevComputer), so
// these are initialized only once per model.
struct AEVScalarParams {
  float Rcr; // radial cutoff radius
  float Rca; // angular cutoff radius
  int radial_sublength; // EtaR.size(0) * ShfR.size(0)
  int radial_length; // radial_sublength * num_species
  int angular_sublength; // EtaA.size(0) * Zeta.size(0) * ShfA.size(0) * ShfZ.size(0)
  int angular_length; // angular_sublength * num_species * (num_species + 1) / 2
  int num_species;
  Tensor EtaR_t;
  Tensor ShfR_t;
  Tensor EtaA_t;
  Tensor Zeta_t;
  Tensor ShfA_t;
  Tensor ShfZ_t;

  // Derives the *_sublength / *_length fields from the tensor sizes.
  AEVScalarParams(
      float Rcr_,
      float Rca_,
      Tensor EtaR_t_,
      Tensor ShfR_t_,
      Tensor EtaA_t_,
      Tensor Zeta_t_,
      Tensor ShfA_t_,
      Tensor ShfZ_t_,
      int num_species_);
};
// Bundle of the forward pass output (aev_t) plus every intermediate tensor
// and size the backward / double-backward kernels need, so the autograd
// functions only have to carry one object.
struct Result {
  Tensor aev_t;
  // Pairwise-distance buffers reused by the backward kernels.
  Tensor tensor_Rij;
  Tensor tensor_radialRij;
  Tensor tensor_angularRij;
  int total_natom_pairs;
  int nRadialRij;
  int nAngularRij;
  // Per-center-atom bookkeeping for the angular terms.
  Tensor tensor_centralAtom;
  Tensor tensor_numPairsPerCenterAtom;
  Tensor tensor_centerAtomStartIdx;
  int maxnbrs_per_atom_aligned;
  int angular_length_aligned;
  int ncenter_atoms;
  // Inputs are kept here too, so backward needs nothing but a Result.
  Tensor coordinates_t;
  Tensor species_t;

  Result(
      Tensor aev_t_,
      Tensor tensor_Rij_,
      Tensor tensor_radialRij_,
      Tensor tensor_angularRij_,
      int64_t total_natom_pairs_,
      int64_t nRadialRij_,
      int64_t nAngularRij_,
      Tensor tensor_centralAtom_,
      Tensor tensor_numPairsPerCenterAtom_,
      Tensor tensor_centerAtomStartIdx_,
      int64_t maxnbrs_per_atom_aligned_,
      int64_t angular_length_aligned_,
      int64_t ncenter_atoms_,
      Tensor coordinates_t_,
      Tensor species_t_);
  // Rebuilds a Result from the flat list produced by operator tensor_list().
  Result(tensor_list tensors);
  // Flattens the Result for ctx->save_for_backward(). Scalar fields are
  // wrapped as 0-dim tensors; aev_t is intentionally dropped (undefined
  // first slot) so the full AEV is not kept alive until backward.
  operator tensor_list() {
    return {Tensor(), // aev_t got removed
            tensor_Rij,
            tensor_radialRij,
            tensor_angularRij,
            torch::tensor(total_natom_pairs),
            torch::tensor(nRadialRij),
            torch::tensor(nAngularRij),
            tensor_centralAtom,
            tensor_numPairsPerCenterAtom,
            tensor_centerAtomStartIdx,
            torch::tensor(maxnbrs_per_atom_aligned),
            torch::tensor(angular_length_aligned),
            torch::tensor(ncenter_atoms),
            coordinates_t,
            species_t};
  }
};
// CUDA kernel launchers (implemented in the .cu file).
// Forward: computes the AEVs plus all intermediates needed by backward.
Result cuaev_forward(const Tensor& coordinates_t, const Tensor& species_t, const AEVScalarParams& aev_params);
// Backward: grad_output (dE/daev) -> grad_coord (force).
Tensor cuaev_backward(const Tensor& grad_output, const AEVScalarParams& aev_params, const Result& result);
// Double backward (force training): grad_force -> grad_grad_aev.
Tensor cuaev_double_backward(const Tensor& grad_force, const AEVScalarParams& aev_params, const Result& result);
// CuaevComputer
// TorchScript custom class that owns the single shared copy of the AEV
// parameters; forward/backward/double_backward all borrow it, so the
// parameters are transferred to the GPU only once.
struct CuaevComputer : torch::CustomClassHolder {
  AEVScalarParams aev_params;

  CuaevComputer(
      double Rcr,
      double Rca,
      const Tensor& EtaR_t,
      const Tensor& ShfR_t,
      const Tensor& EtaA_t,
      const Tensor& Zeta_t,
      const Tensor& ShfA_t,
      const Tensor& ShfZ_t,
      int64_t num_species);

  // Runs the CUDA forward; the Result carries all backward intermediates.
  Result forward(const Tensor& coordinates_t, const Tensor& species_t) {
    return cuaev_forward(coordinates_t, species_t, aev_params);
  }

  Tensor backward(const Tensor& grad_e_aev, const Result& result) {
    return cuaev_backward(grad_e_aev, aev_params, result); // force
  }

  Tensor double_backward(const Tensor& grad_force, const Result& result) {
    return cuaev_double_backward(grad_force, aev_params, result); // grad_grad_aev
  }
};
// Autograd functions
// Differentiable backward of the AEV computation: forward() maps grad_e_aev
// to grad_coord (force); backward() provides double backward for force
// training. result_tensors is the flattened Result saved by CuaevAutograd.
class CuaevDoubleAutograd : public torch::autograd::Function<CuaevDoubleAutograd> {
 public:
  static Tensor forward(
      AutogradContext* ctx,
      Tensor grad_e_aev,
      const torch::intrusive_ptr<CuaevComputer>& cuaev_computer,
      tensor_list result_tensors);

  static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs);
};
// Autograd wrapper around CuaevComputer::forward: returns the AEV tensor and
// registers CuaevDoubleAutograd as its backward when gradients are required.
class CuaevAutograd : public torch::autograd::Function<CuaevAutograd> {
 public:
  static Tensor forward(
      AutogradContext* ctx,
      const Tensor& coordinates_t,
      const Tensor& species_t,
      const torch::intrusive_ptr<CuaevComputer>& cuaev_computer);

  static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs);
};
#include <aev.h>

#include <torch/extension.h>

#include <utility>
using torch::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::tensor_list;
// Builds the shared AEV parameter pack. Feature segment lengths are derived
// from the parameter tensor sizes:
//   radial_sublength  = |EtaR| * |ShfR|
//   angular_sublength = |EtaA| * |Zeta| * |ShfA| * |ShfZ|
//   radial_length     = radial_sublength * num_species
//   angular_length    = angular_sublength * num_species * (num_species + 1) / 2
// Tensor parameters are taken by value and moved into the members; the
// original copied them, paying an extra atomic refcount bump per tensor.
// All members are now set in the init list (sizes are read before the moves,
// since the scalar members are declared before the tensor members).
AEVScalarParams::AEVScalarParams(
    float Rcr_,
    float Rca_,
    Tensor EtaR_t_,
    Tensor ShfR_t_,
    Tensor EtaA_t_,
    Tensor Zeta_t_,
    Tensor ShfA_t_,
    Tensor ShfZ_t_,
    int num_species_)
    : Rcr(Rcr_),
      Rca(Rca_),
      radial_sublength(EtaR_t_.size(0) * ShfR_t_.size(0)),
      radial_length(radial_sublength * num_species_),
      angular_sublength(EtaA_t_.size(0) * Zeta_t_.size(0) * ShfA_t_.size(0) * ShfZ_t_.size(0)),
      angular_length(angular_sublength * (num_species_ * (num_species_ + 1) / 2)),
      num_species(num_species_),
      EtaR_t(std::move(EtaR_t_)),
      ShfR_t(std::move(ShfR_t_)),
      EtaA_t(std::move(EtaA_t_)),
      Zeta_t(std::move(Zeta_t_)),
      ShfA_t(std::move(ShfA_t_)),
      ShfZ_t(std::move(ShfZ_t_)) {}
// Packs the output of the forward CUDA kernels together with every
// intermediate needed by backward. Tensor parameters are taken by value and
// moved into place; the original copy-initialized each member, which costs
// an atomic refcount increment per tensor. Scalar parameters arrive as
// int64_t and narrow to the int members as before.
Result::Result(
    Tensor aev_t_,
    Tensor tensor_Rij_,
    Tensor tensor_radialRij_,
    Tensor tensor_angularRij_,
    int64_t total_natom_pairs_,
    int64_t nRadialRij_,
    int64_t nAngularRij_,
    Tensor tensor_centralAtom_,
    Tensor tensor_numPairsPerCenterAtom_,
    Tensor tensor_centerAtomStartIdx_,
    int64_t maxnbrs_per_atom_aligned_,
    int64_t angular_length_aligned_,
    int64_t ncenter_atoms_,
    Tensor coordinates_t_,
    Tensor species_t_)
    : aev_t(std::move(aev_t_)),
      tensor_Rij(std::move(tensor_Rij_)),
      tensor_radialRij(std::move(tensor_radialRij_)),
      tensor_angularRij(std::move(tensor_angularRij_)),
      total_natom_pairs(total_natom_pairs_),
      nRadialRij(nRadialRij_),
      nAngularRij(nAngularRij_),
      tensor_centralAtom(std::move(tensor_centralAtom_)),
      tensor_numPairsPerCenterAtom(std::move(tensor_numPairsPerCenterAtom_)),
      tensor_centerAtomStartIdx(std::move(tensor_centerAtomStartIdx_)),
      maxnbrs_per_atom_aligned(maxnbrs_per_atom_aligned_),
      angular_length_aligned(angular_length_aligned_),
      ncenter_atoms(ncenter_atoms_),
      coordinates_t(std::move(coordinates_t_)),
      species_t(std::move(species_t_)) {}
// Rebuilds a Result from the flat tensor_list produced by
// Result::operator tensor_list() — indices must stay in the same order as
// that conversion. Scalar fields come back via .item<int>() from the 0-dim
// tensors created with torch::tensor(...).
Result::Result(tensor_list tensors)
    : aev_t(tensors[0]), // aev_t is an undefined tensor here: it was dropped in the flattening to free memory
      tensor_Rij(tensors[1]),
      tensor_radialRij(tensors[2]),
      tensor_angularRij(tensors[3]),
      total_natom_pairs(tensors[4].item<int>()),
      nRadialRij(tensors[5].item<int>()),
      nAngularRij(tensors[6].item<int>()),
      tensor_centralAtom(tensors[7]),
      tensor_numPairsPerCenterAtom(tensors[8]),
      tensor_centerAtomStartIdx(tensors[9]),
      maxnbrs_per_atom_aligned(tensors[10].item<int>()),
      angular_length_aligned(tensors[11].item<int>()),
      ncenter_atoms(tensors[12].item<int>()),
      coordinates_t(tensors[13]),
      species_t(tensors[14]) {}
// Constructs the custom-class holder; AEVScalarParams derives the feature
// segment lengths from the tensor sizes. The double cutoffs narrow to the
// float fields of AEVScalarParams.
CuaevComputer::CuaevComputer(
    double Rcr,
    double Rca,
    const Tensor& EtaR_t,
    const Tensor& ShfR_t,
    const Tensor& EtaA_t,
    const Tensor& Zeta_t,
    const Tensor& ShfA_t,
    const Tensor& ShfZ_t,
    int64_t num_species)
    : aev_params(Rcr, Rca, EtaR_t, ShfR_t, EtaA_t, Zeta_t, ShfA_t, ShfZ_t, num_species) {}
// First-order backward of the AEV op: maps grad_e_aev (dE/daev) to
// grad_coord (force). result_tensors is the flattened Result saved by
// CuaevAutograd::forward; it converts back to a Result implicitly via
// Result(tensor_list) when passed to CuaevComputer::backward.
Tensor CuaevDoubleAutograd::forward(
    AutogradContext* ctx,
    Tensor grad_e_aev,
    const torch::intrusive_ptr<CuaevComputer>& cuaev_computer,
    tensor_list result_tensors) {
  Tensor grad_coord = cuaev_computer->backward(grad_e_aev, result_tensors);

  // Save state only when a double backward (force training) can actually run.
  if (grad_e_aev.requires_grad()) {
    ctx->saved_data["cuaev_computer"] = cuaev_computer;
    ctx->save_for_backward(result_tensors);
  }

  return grad_coord;
}
// Double backward: grad_force -> grad_grad_aev. The saved variables convert
// back to a Result implicitly inside double_backward.
tensor_list CuaevDoubleAutograd::backward(AutogradContext* ctx, tensor_list grad_outputs) {
  Tensor grad_force = grad_outputs[0];
  torch::intrusive_ptr<CuaevComputer> cuaev_computer = ctx->saved_data["cuaev_computer"].toCustomClass<CuaevComputer>();
  Tensor grad_grad_aev = cuaev_computer->double_backward(grad_force, ctx->get_saved_variables());
  // One slot per forward() input: grad_e_aev, cuaev_computer, result_tensors.
  return {grad_grad_aev, torch::Tensor(), torch::Tensor()};
}
// Autograd-visible forward: computes the AEVs and, if gradients are needed,
// saves the flattened Result for CuaevAutograd::backward.
Tensor CuaevAutograd::forward(
    AutogradContext* ctx,
    const Tensor& coordinates_t,
    const Tensor& species_t,
    const torch::intrusive_ptr<CuaevComputer>& cuaev_computer) {
  // Run the kernels without autograd recording; gradients are supplied by
  // this Function's custom backward instead.
  at::AutoNonVariableTypeMode g;
  Result result = cuaev_computer->forward(coordinates_t, species_t);
  if (coordinates_t.requires_grad()) {
    ctx->saved_data["cuaev_computer"] = cuaev_computer;
    // Result converts to a tensor_list; the aev_t slot is left undefined so
    // the full AEV tensor is not kept alive until backward.
    ctx->save_for_backward(result);
  }
  return result.aev_t;
}
// Delegates to CuaevDoubleAutograd so the backward is itself differentiable.
tensor_list CuaevAutograd::backward(AutogradContext* ctx, tensor_list grad_outputs) {
  torch::intrusive_ptr<CuaevComputer> cuaev_computer = ctx->saved_data["cuaev_computer"].toCustomClass<CuaevComputer>();
  tensor_list result_tensors = ctx->get_saved_variables();
  Tensor grad_coord = CuaevDoubleAutograd::apply(grad_outputs[0], cuaev_computer, result_tensors);
  // One slot per forward() input; only coordinates_t receives a gradient.
  return {grad_coord, Tensor(), Tensor()};
}
// Inference-only entry point: computes the AEVs without building an autograd
// graph, returning just the AEV tensor and dropping every intermediate.
Tensor run_only_forward(
    const Tensor& coordinates_t,
    const Tensor& species_t,
    const torch::intrusive_ptr<CuaevComputer>& cuaev_computer) {
  return cuaev_computer->forward(coordinates_t, species_t).aev_t;
}
// Autograd-enabled entry point: dispatches through CuaevAutograd so the
// returned AEVs are differentiable w.r.t. coordinates.
Tensor run_autograd(
    const Tensor& coordinates_t,
    const Tensor& species_t,
    const torch::intrusive_ptr<CuaevComputer>& cuaev_computer) {
  return CuaevAutograd::apply(coordinates_t, species_t, cuaev_computer);
}
// Register the custom class (holder of the one shared parameter copy) and
// the "run" op schema with TorchScript.
TORCH_LIBRARY(cuaev, m) {
  m.class_<CuaevComputer>("CuaevComputer")
      .def(torch::init<double, double, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, int64_t>());
  m.def("run", run_only_forward);
}

// CUDA dispatch key: plain forward, no autograd graph built.
TORCH_LIBRARY_IMPL(cuaev, CUDA, m) {
  m.impl("run", run_only_forward);
}

// Autograd dispatch key: wraps the computation in CuaevAutograd.
TORCH_LIBRARY_IMPL(cuaev, Autograd, m) {
  m.impl("run", run_autograd);
}

// The extension exposes nothing via pybind; all access goes through the
// torch library registrations above.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment