Unverified Commit c4f5d3b4 authored by Jinze Xue, committed by GitHub

CUAEV refactor to initialize aev parameter only once (#575)



* init

* separate to h cpp cu files

* works

* fix

* clean

* update

* fix

* only keep one copy of Result

* const result&

* move declaration to header

* fix jit

* clean

* fix

* fix jit

* fix ci test

* clean

* revert to autograd in charge of its result

* update

* fix memory issue of result.aev_t

* Update torchani/cuaev/aev.h
Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com>

* Update torchani/cuaev/aev.h
Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com>

* Update torchani/cuaev/aev.h
Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com>

* save

* Update torchani/cuaev/aev.h
Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com>

* update

* Update torchani/cuaev/cuaev.cpp
Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com>

* Update torchani/cuaev/cuaev.cpp
Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com>

* Update torchani/cuaev/aev.h
Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com>

* remove aev_t after forward

* remove release function

* fix

* fix jit

* update
Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com>
parent 191cf2d1
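
In short, the AEV constants are no longer passed to the CUDA op on every call; they are held by a TorchScript custom class that is built once and reused. Roughly, at the op level (an illustrative before/after sketch, not code from this commit; the tensors and scalars referenced here are placeholders, and a complete scripted example appears in the test diff below):

# before: every constant passed on each call
aev = torch.ops.cuaev.cuComputeAEV(coordinates, species_int, Rcr, Rca,
                                   EtaR.flatten(), ShfR.flatten(), EtaA.flatten(), Zeta.flatten(),
                                   ShfA.flatten(), ShfZ.flatten(), num_species)

# after: build the holder once, then reuse it for every call
cuaev_computer = torch.classes.cuaev.CuaevComputer(Rcr, Rca, EtaR.flatten(), ShfR.flatten(),
                                                   EtaA.flatten(), Zeta.flatten(), ShfA.flatten(),
                                                   ShfZ.flatten(), num_species)
aev = torch.ops.cuaev.run(coordinates, species_int, cuaev_computer)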
setup.py:

 import os
-import glob
 import subprocess
 from setuptools import setup, find_packages
 from distutils import log
@@ -91,11 +90,12 @@ def cuda_extension(build_all=False):
         nvcc_args.append("-gencode=arch=compute_86,code=sm_86")
     print("nvcc_args: ", nvcc_args)
     print('-' * 75)
+    include_dirs = [*maybe_download_cub(), os.path.abspath("torchani/cuaev/")]
     return CUDAExtension(
         name='torchani.cuaev',
         pkg='torchani.cuaev',
-        sources=glob.glob('torchani/cuaev/*.cu'),
-        include_dirs=maybe_download_cub(),
+        sources=["torchani/cuaev/cuaev.cpp", "torchani/cuaev/aev.cu"],
+        include_dirs=include_dirs,
         extra_compile_args={'cxx': ['-std=c++14'], 'nvcc': nvcc_args})
tests/test_cuaev.py:

@@ -18,9 +18,10 @@ class TestCUAEVNoGPU(TestCase):

     def testSimple(self):
         def f(coordinates, species, Rcr: float, Rca: float, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species: int):
-            return torch.ops.cuaev.cuComputeAEV(coordinates, species, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)
+            cuaev_computer = torch.classes.cuaev.CuaevComputer(Rcr, Rca, EtaR.flatten(), ShfR.flatten(), EtaA.flatten(), Zeta.flatten(), ShfA.flatten(), ShfZ.flatten(), num_species)
+            return torch.ops.cuaev.run(coordinates, species, cuaev_computer)
         s = torch.jit.script(f)
-        self.assertIn("cuaev::cuComputeAEV", str(s.graph))
+        self.assertIn("cuaev::run", str(s.graph))

     def testAEVComputer(self):
         path = os.path.dirname(os.path.realpath(__file__))
@@ -31,7 +32,7 @@ class TestCUAEVNoGPU(TestCase):
         # Computation of AEV using cuaev when there is no atoms does not require CUDA, and can be run without GPU
         species = make_tensor((8, 0), 'cpu', torch.int64, low=-1, high=4)
         coordinates = make_tensor((8, 0, 3), 'cpu', torch.float32, low=-5, high=5)
-        self.assertIn("cuaev::cuComputeAEV", str(s.graph_for((species, coordinates))))
+        self.assertIn("cuaev::run", str(s.graph_for((species, coordinates))))

     @skipIfNoGPU
torchani/aev.py:

@@ -323,18 +323,12 @@ def compute_aev(species: Tensor, coordinates: Tensor, triu_index: Tensor,
     return torch.cat([radial_aev, angular_aev], dim=-1)


-def compute_cuaev(species: Tensor, coordinates: Tensor, triu_index: Tensor,
-                  constants: Tuple[float, Tensor, Tensor, float, Tensor, Tensor, Tensor, Tensor],
-                  num_species: int, cell_shifts: Optional[Tuple[Tensor, Tensor]]) -> Tensor:
-    Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA = constants
-    assert cell_shifts is None, "Current implementation of cuaev does not support pbc."
-    species_int = species.to(torch.int32)
-    return torch.ops.cuaev.cuComputeAEV(coordinates, species_int, Rcr, Rca, EtaR.flatten(), ShfR.flatten(), EtaA.flatten(), Zeta.flatten(), ShfA.flatten(), ShfZ.flatten(), num_species)
-
-
-if not has_cuaev:
-    compute_cuaev = torch.jit.unused(compute_cuaev)
+def jit_unused_if_no_cuaev(condition=has_cuaev):
+    def decorator(func):
+        if not condition:
+            return torch.jit.unused(func)
+        return func
+    return decorator


 class AEVComputer(torch.nn.Module):
@@ -421,6 +415,24 @@ class AEVComputer(torch.nn.Module):
         self.register_buffer('default_cell', default_cell)
         self.register_buffer('default_shifts', default_shifts)

+        # Should be created only when use_cuda_extension is True. However, jit needs to know
+        # cuaev_computer's type even when use_cuda_extension is False, because the extension
+        # can be enabled later whenever cuaev is available.
+        if has_cuaev:
+            self.init_cuaev_computer()
+        # If has_cuaev is True, use_cuda_extension is False, and the user enables
+        # use_cuda_extension afterwards, another init_cuaev_computer call will be needed.
+        self.cuaev_enabled = True if self.use_cuda_extension else False
+
+    @jit_unused_if_no_cuaev()
+    def init_cuaev_computer(self):
+        self.cuaev_computer = torch.classes.cuaev.CuaevComputer(self.Rcr, self.Rca, self.EtaR.flatten(), self.ShfR.flatten(), self.EtaA.flatten(), self.Zeta.flatten(), self.ShfA.flatten(), self.ShfZ.flatten(), self.num_species)
+
+    @jit_unused_if_no_cuaev()
+    def compute_cuaev(self, species, coordinates):
+        species_int = species.to(torch.int32)
+        aev = torch.ops.cuaev.run(coordinates, species_int, self.cuaev_computer)
+        return aev
+
     @classmethod
     def cover_linearly(cls, radial_cutoff: float, angular_cutoff: float,
                        radial_eta: float, angular_eta: float,
@@ -505,8 +517,11 @@ class AEVComputer(torch.nn.Module):
         assert coordinates.shape[-1] == 3

         if self.use_cuda_extension:
-            assert (cell is None and pbc is None), "cuaev does not support PBC"
-            aev = compute_cuaev(species, coordinates, self.triu_index, self.constants(), self.num_species, None)
+            assert (cell is None and pbc is None), "cuaev currently does not support PBC"
+            # in case use_cuda_extension was enabled after initialization
+            if not self.cuaev_enabled:
+                self.init_cuaev_computer()
+            aev = self.compute_cuaev(species, coordinates)
             return SpeciesAEV(species, aev)

         if cell is None and pbc is None:
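
For illustration only (not part of the commit), the lazy path when the extension is switched on after construction; the AEVComputer constructor arguments below are assumed, not taken from this diff:

# assumes torchani was built with cuaev, so has_cuaev is True
aev_computer = AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species,
                           use_cuda_extension=False)
# __init__ already called init_cuaev_computer() (has_cuaev is True), but cuaev_enabled is False
aev_computer.use_cuda_extension = True
# forward() now sees `not self.cuaev_enabled`, calls init_cuaev_computer() again,
# and takes the compute_cuaev() path instead of the pure-PyTorch one
species_aev = aev_computer((species, coordinates))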
(One additional file diff in this commit is collapsed and not reproduced here.)
torchani/cuaev/aev.h:

#pragma once
#include <c10/cuda/CUDACachingAllocator.h>
#include <torch/extension.h>
using torch::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::tensor_list;
// [Computation graph for forward, backward, and double backward]
//
// backward
// force = (dE / daev) * (daev / dcoord) = g * (daev / dcoord)
//
// double backward (to do force training, the term needed is)
// dloss / dg = (dloss / dforce) * (dforce / dg) = (dloss / dforce) * (daev / dcoord)
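// Equivalently, in index form (same sign convention as above), with g = dE / daev saved from the
// energy backward pass:
//   force_i      = sum_k g_k * (daev_k / dcoord_i)
//   (dloss/dg)_k = sum_i (dloss / dforce_i) * (daev_k / dcoord_i)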
//
//
// [Forward]
// out ^
// | ^
// ... ^
// | ^
// e n e r g y ^
// | \ ^
// aev \ ^
// / | \ ^
// radial angular params ^
// / / | ^
// dist---^ / ^
// \ / ^
// coord ^
//
// Functional relationship:
// coord <-- input
// dist(coord)
// radial(dist)
// angular(dist, coord)
// aev = concatenate(radial, angular)
// energy(aev, params)
// out(energy, ....) <-- output
//
//
// [Backward]
// dout v
// | v
// ... v
// | v
// aev params denergy aev params v
// \ | / \ | / v
// d a e v dparams v
// / \____ v
// dist dradial \ v
// \ / \ v
// ddist dist coord dangular dist coord v
// \ / / \ | / v
// \_/____/ \___|___/ v
// | __________________/ v
// | / v
// dcoord v
// | v
// ... v
// | v
// out2 v
//
// Functional relationship:
// dout <-- input
// denergy(dout)
// dparams(denergy, aev, params) <-- output
// daev(denergy, aev, params)
// dradial = slice(daev)
// dangular = slice(daev)
// ddist = radial_backward(dradial, dist) + angular_backward_dist(dangular, ...)
// = radial_backward(dradial, dist) + 0 (all contributions route to dcoord)
// = radial_backward(dradial, dist)
// dcoord = dist_backward(ddist, coord, dist) + angular_backward_coord(dangular, coord, dist)
// out2(dcoord, ...) <-- output
//
//
// [Double backward w.r.t params (i.e. force training)]
// Note: only a very limited subset of double backward is implemented
// currently it can only do force training, there is no hessian support
// not implemented terms are marked by $s
// $$$ [dparams] $$$$ ^
// \_ | __/ ^
// [ddaev] ^
// / \_____ ^
// $$$$ [ddradial] \ ^
// \ / \ ^
// [dddist] $$$$ $$$$ [ddangular] $$$$ $$$$ ^
// \ / / \ | / ^
// \_/____/ \_____|___/ ^
// | _____________________/ ^
// | / ^
// [ddcoord] ^
// | ^
// ... ^
// | ^
// [dout2] ^
//
// Functional relationship:
// dout2 <-- input
// ddcoord(dout2, ...)
// dddist = dist_doublebackward(ddcoord, coord, dist)
// ddradial = radial_doublebackward(dddist, dist)
// ddangular = angular_doublebackward(ddcord, coord, dist)
// ddaev = concatenate(ddradial, ddangular)
// dparams(ddaev, ...) <-- output
struct AEVScalarParams {
float Rcr;
float Rca;
int radial_sublength;
int radial_length;
int angular_sublength;
int angular_length;
int num_species;
Tensor EtaR_t;
Tensor ShfR_t;
Tensor EtaA_t;
Tensor Zeta_t;
Tensor ShfA_t;
Tensor ShfZ_t;
AEVScalarParams(
float Rcr_,
float Rca_,
Tensor EtaR_t_,
Tensor ShfR_t_,
Tensor EtaA_t_,
Tensor Zeta_t_,
Tensor ShfA_t_,
Tensor ShfZ_t_,
int num_species_);
};
struct Result {
Tensor aev_t;
Tensor tensor_Rij;
Tensor tensor_radialRij;
Tensor tensor_angularRij;
int total_natom_pairs;
int nRadialRij;
int nAngularRij;
Tensor tensor_centralAtom;
Tensor tensor_numPairsPerCenterAtom;
Tensor tensor_centerAtomStartIdx;
int maxnbrs_per_atom_aligned;
int angular_length_aligned;
int ncenter_atoms;
Tensor coordinates_t;
Tensor species_t;
Result(
Tensor aev_t_,
Tensor tensor_Rij_,
Tensor tensor_radialRij_,
Tensor tensor_angularRij_,
int64_t total_natom_pairs_,
int64_t nRadialRij_,
int64_t nAngularRij_,
Tensor tensor_centralAtom_,
Tensor tensor_numPairsPerCenterAtom_,
Tensor tensor_centerAtomStartIdx_,
int64_t maxnbrs_per_atom_aligned_,
int64_t angular_length_aligned_,
int64_t ncenter_atoms_,
Tensor coordinates_t_,
Tensor species_t_);
Result(tensor_list tensors);
operator tensor_list() {
    return {Tensor(), // aev_t is intentionally dropped: it is not needed for backward, and keeping it here would hold extra memory
tensor_Rij,
tensor_radialRij,
tensor_angularRij,
torch::tensor(total_natom_pairs),
torch::tensor(nRadialRij),
torch::tensor(nAngularRij),
tensor_centralAtom,
tensor_numPairsPerCenterAtom,
tensor_centerAtomStartIdx,
torch::tensor(maxnbrs_per_atom_aligned),
torch::tensor(angular_length_aligned),
torch::tensor(ncenter_atoms),
coordinates_t,
species_t};
}
};
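// The tensor_list round trip above (operator tensor_list plus the Result(tensor_list) constructor)
// is what lets the autograd functions below stash a Result via ctx->save_for_backward() and rebuild
// it from the saved tensors in their backward passes; the scalar members travel as 0-dim tensors.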
// cuda kernels
Result cuaev_forward(const Tensor& coordinates_t, const Tensor& species_t, const AEVScalarParams& aev_params);
Tensor cuaev_backward(const Tensor& grad_output, const AEVScalarParams& aev_params, const Result& result);
Tensor cuaev_double_backward(const Tensor& grad_force, const AEVScalarParams& aev_params, const Result& result);
// CuaevComputer
// Only keep one copy of aev parameters
struct CuaevComputer : torch::CustomClassHolder {
AEVScalarParams aev_params;
CuaevComputer(
double Rcr,
double Rca,
const Tensor& EtaR_t,
const Tensor& ShfR_t,
const Tensor& EtaA_t,
const Tensor& Zeta_t,
const Tensor& ShfA_t,
const Tensor& ShfZ_t,
int64_t num_species);
Result forward(const Tensor& coordinates_t, const Tensor& species_t) {
return cuaev_forward(coordinates_t, species_t, aev_params);
}
Tensor backward(const Tensor& grad_e_aev, const Result& result) {
return cuaev_backward(grad_e_aev, aev_params, result); // force
}
Tensor double_backward(const Tensor& grad_force, const Result& result) {
return cuaev_double_backward(grad_force, aev_params, result); // grad_grad_aev
}
};
// Autograd functions
class CuaevDoubleAutograd : public torch::autograd::Function<CuaevDoubleAutograd> {
public:
static Tensor forward(
AutogradContext* ctx,
Tensor grad_e_aev,
const torch::intrusive_ptr<CuaevComputer>& cuaev_computer,
tensor_list result_tensors);
static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs);
};
class CuaevAutograd : public torch::autograd::Function<CuaevAutograd> {
public:
static Tensor forward(
AutogradContext* ctx,
const Tensor& coordinates_t,
const Tensor& species_t,
const torch::intrusive_ptr<CuaevComputer>& cuaev_computer);
static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs);
};
torchani/cuaev/cuaev.cpp:

#include <aev.h>
#include <torch/extension.h>
using torch::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::tensor_list;
AEVScalarParams::AEVScalarParams(
float Rcr_,
float Rca_,
Tensor EtaR_t_,
Tensor ShfR_t_,
Tensor EtaA_t_,
Tensor Zeta_t_,
Tensor ShfA_t_,
Tensor ShfZ_t_,
int num_species_)
: Rcr(Rcr_),
Rca(Rca_),
radial_sublength(EtaR_t_.size(0) * ShfR_t_.size(0)),
angular_sublength(EtaA_t_.size(0) * Zeta_t_.size(0) * ShfA_t_.size(0) * ShfZ_t_.size(0)),
num_species(num_species_),
EtaR_t(EtaR_t_),
ShfR_t(ShfR_t_),
EtaA_t(EtaA_t_),
Zeta_t(Zeta_t_),
ShfA_t(ShfA_t_),
ShfZ_t(ShfZ_t_) {
radial_length = radial_sublength * num_species;
angular_length = angular_sublength * (num_species * (num_species + 1) / 2);
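  // Example (ANI-1x-style constants, used here only for illustration): 1 EtaR x 16 ShfR gives
  // radial_sublength = 16 and radial_length = 16 * 4 = 64; 1 EtaA x 1 Zeta x 4 ShfA x 8 ShfZ gives
  // angular_sublength = 32 and angular_length = 32 * (4 * 5 / 2) = 320, i.e. a 384-long AEV in total.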
}
Result::Result(
Tensor aev_t_,
Tensor tensor_Rij_,
Tensor tensor_radialRij_,
Tensor tensor_angularRij_,
int64_t total_natom_pairs_,
int64_t nRadialRij_,
int64_t nAngularRij_,
Tensor tensor_centralAtom_,
Tensor tensor_numPairsPerCenterAtom_,
Tensor tensor_centerAtomStartIdx_,
int64_t maxnbrs_per_atom_aligned_,
int64_t angular_length_aligned_,
int64_t ncenter_atoms_,
Tensor coordinates_t_,
Tensor species_t_)
: aev_t(aev_t_),
tensor_Rij(tensor_Rij_),
tensor_radialRij(tensor_radialRij_),
tensor_angularRij(tensor_angularRij_),
total_natom_pairs(total_natom_pairs_),
nRadialRij(nRadialRij_),
nAngularRij(nAngularRij_),
tensor_centralAtom(tensor_centralAtom_),
tensor_numPairsPerCenterAtom(tensor_numPairsPerCenterAtom_),
tensor_centerAtomStartIdx(tensor_centerAtomStartIdx_),
maxnbrs_per_atom_aligned(maxnbrs_per_atom_aligned_),
angular_length_aligned(angular_length_aligned_),
ncenter_atoms(ncenter_atoms_),
coordinates_t(coordinates_t_),
species_t(species_t_) {}
Result::Result(tensor_list tensors)
    : aev_t(tensors[0]), // aev_t will be an undefined tensor
tensor_Rij(tensors[1]),
tensor_radialRij(tensors[2]),
tensor_angularRij(tensors[3]),
total_natom_pairs(tensors[4].item<int>()),
nRadialRij(tensors[5].item<int>()),
nAngularRij(tensors[6].item<int>()),
tensor_centralAtom(tensors[7]),
tensor_numPairsPerCenterAtom(tensors[8]),
tensor_centerAtomStartIdx(tensors[9]),
maxnbrs_per_atom_aligned(tensors[10].item<int>()),
angular_length_aligned(tensors[11].item<int>()),
ncenter_atoms(tensors[12].item<int>()),
coordinates_t(tensors[13]),
species_t(tensors[14]) {}
CuaevComputer::CuaevComputer(
double Rcr,
double Rca,
const Tensor& EtaR_t,
const Tensor& ShfR_t,
const Tensor& EtaA_t,
const Tensor& Zeta_t,
const Tensor& ShfA_t,
const Tensor& ShfZ_t,
int64_t num_species)
: aev_params(Rcr, Rca, EtaR_t, ShfR_t, EtaA_t, Zeta_t, ShfA_t, ShfZ_t, num_species) {}
Tensor CuaevDoubleAutograd::forward(
AutogradContext* ctx,
Tensor grad_e_aev,
const torch::intrusive_ptr<CuaevComputer>& cuaev_computer,
tensor_list result_tensors) {
Tensor grad_coord = cuaev_computer->backward(grad_e_aev, result_tensors);
if (grad_e_aev.requires_grad()) {
ctx->saved_data["cuaev_computer"] = cuaev_computer;
ctx->save_for_backward(result_tensors);
}
return grad_coord;
}
tensor_list CuaevDoubleAutograd::backward(AutogradContext* ctx, tensor_list grad_outputs) {
Tensor grad_force = grad_outputs[0];
torch::intrusive_ptr<CuaevComputer> cuaev_computer = ctx->saved_data["cuaev_computer"].toCustomClass<CuaevComputer>();
Tensor grad_grad_aev = cuaev_computer->double_backward(grad_force, ctx->get_saved_variables());
return {grad_grad_aev, torch::Tensor(), torch::Tensor()};
}
Tensor CuaevAutograd::forward(
AutogradContext* ctx,
const Tensor& coordinates_t,
const Tensor& species_t,
const torch::intrusive_ptr<CuaevComputer>& cuaev_computer) {
at::AutoNonVariableTypeMode g;
Result result = cuaev_computer->forward(coordinates_t, species_t);
if (coordinates_t.requires_grad()) {
ctx->saved_data["cuaev_computer"] = cuaev_computer;
ctx->save_for_backward(result);
}
return result.aev_t;
}
tensor_list CuaevAutograd::backward(AutogradContext* ctx, tensor_list grad_outputs) {
torch::intrusive_ptr<CuaevComputer> cuaev_computer = ctx->saved_data["cuaev_computer"].toCustomClass<CuaevComputer>();
tensor_list result_tensors = ctx->get_saved_variables();
Tensor grad_coord = CuaevDoubleAutograd::apply(grad_outputs[0], cuaev_computer, result_tensors);
return {grad_coord, Tensor(), Tensor()};
}
Tensor run_only_forward(
const Tensor& coordinates_t,
const Tensor& species_t,
const torch::intrusive_ptr<CuaevComputer>& cuaev_computer) {
Result result = cuaev_computer->forward(coordinates_t, species_t);
return result.aev_t;
}
Tensor run_autograd(
const Tensor& coordinates_t,
const Tensor& species_t,
const torch::intrusive_ptr<CuaevComputer>& cuaev_computer) {
return CuaevAutograd::apply(coordinates_t, species_t, cuaev_computer);
}
TORCH_LIBRARY(cuaev, m) {
m.class_<CuaevComputer>("CuaevComputer")
.def(torch::init<double, double, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, int64_t>());
m.def("run", run_only_forward);
}
TORCH_LIBRARY_IMPL(cuaev, CUDA, m) {
m.impl("run", run_only_forward);
}
TORCH_LIBRARY_IMPL(cuaev, Autograd, m) {
m.impl("run", run_autograd);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
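
For illustration (not part of the commit): a minimal sketch of the force-training flow that exercises the double-backward path above. It assumes a built cuaev extension and a CUDA device; the constants, the molecule, `model`, and `target_forces` are placeholders invented for this sketch.

import torch
import torchani  # importing torchani is assumed to load the built cuaev extension and register its ops/classes

device = 'cuda'
# placeholder ANI-1x-style constants: 1 EtaR x 16 ShfR, 1 EtaA x 1 Zeta x 4 ShfA x 8 ShfZ, 4 species
Rcr, Rca, num_species = 5.2, 3.5, 4
EtaR, ShfR = torch.tensor([16.0], device=device), torch.linspace(0.9, 5.2, 16, device=device)
EtaA, Zeta = torch.tensor([8.0], device=device), torch.tensor([32.0], device=device)
ShfA, ShfZ = torch.linspace(0.9, 3.5, 4, device=device), torch.linspace(0.19, 3.0, 8, device=device)
cuaev_computer = torch.classes.cuaev.CuaevComputer(
    Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)

# AEV length from the AEVScalarParams formulas above: 64 + 320 = 384
aev_dim = (EtaR.numel() * ShfR.numel() * num_species
           + EtaA.numel() * Zeta.numel() * ShfA.numel() * ShfZ.numel()
           * (num_species * (num_species + 1) // 2))
model = torch.nn.Linear(aev_dim, 1).to(device)                 # placeholder per-atom energy head
coordinates = torch.randn(1, 5, 3, device=device, requires_grad=True)
species = torch.randint(0, num_species, (1, 5), dtype=torch.int32, device=device)
target_forces = torch.zeros(1, 5, 3, device=device)            # placeholder labels

aev = torch.ops.cuaev.run(coordinates, species, cuaev_computer)  # CuaevAutograd::forward
energy = model(aev).sum()
# first derivative: CuaevAutograd::backward -> CuaevDoubleAutograd::forward computes grad_coord
force = -torch.autograd.grad(energy, coordinates, create_graph=True)[0]
loss = (force - target_forces).pow(2).sum()
# second derivative: CuaevDoubleAutograd::backward supplies grad_grad_aev, so the loss reaches model's parameters
loss.backward()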