Unverified Commit 23c9816c authored by Jinze Xue's avatar Jinze Xue Committed by GitHub
Browse files

CUAEV backward (#554)



* preparation

* radial preparation 30%

* radial backward kernel done

* reuse Gmr (exp part) result for gradient

* radial kernel every block run by column major, to avoid atomicAdd waiting

* apply code review

* static_cast

* implicit cast

* format

* angular preparation

* angular backward works, but slow, AtomicAdd should be avoided

* angular opti: use shared memory to avoid AtomicAdd

* format

* equation optimization

* remove unnecessary shared mem for atomi

* remove a lot (warpsize * nbr) unnecessary shared mem for atomj

* format

* update

* clean

* fix

* fix

* test file

* fix
Co-authored-by: default avatarGao, Xiang <qasdfgtyuiop@gmail.com>
parent 7cf6823a
...@@ -65,24 +65,70 @@ class TestCUAEV(TestCase): ...@@ -65,24 +65,70 @@ class TestCUAEV(TestCase):
[-4.4978600, 0.8211300, 0.5604100], [-4.4978600, 0.8211300, 0.5604100],
[-4.4978700, -0.8000100, 0.4155600], [-4.4978700, -0.8000100, 0.4155600],
[0.00000000, -0.00000000, -0.00000000]] [0.00000000, -0.00000000, -0.00000000]]
], requires_grad=True, device=self.device) ], device=self.device)
species = torch.tensor([[1, 0, 0, 0, 0], [2, 0, 0, 0, -1]], device=self.device) species = torch.tensor([[1, 0, 0, 0, 0], [2, 0, 0, 0, -1]], device=self.device)
_, aev = self.aev_computer((species, coordinates)) _, aev = self.aev_computer((species, coordinates))
_, cu_aev = self.cuaev_computer((species, coordinates)) _, cu_aev = self.cuaev_computer((species, coordinates))
self.assertEqual(cu_aev, aev) self.assertEqual(cu_aev, aev)
def testSimpleBackward(self):
    """Gradients of the CUDA AEV must match the reference implementation
    on a small hand-written two-molecule batch.
    """
    coordinates = torch.tensor([
        [[0.03192167, 0.00638559, 0.01301679],
         [-0.83140486, 0.39370209, -0.26395324],
         [-0.66518241, -0.84461308, 0.20759389],
         [0.45554739, 0.54289633, 0.81170881],
         [0.66091919, -0.16799635, -0.91037834]],
        [[-4.1862600, 0.0575700, -0.0381200],
         [-3.1689400, 0.0523700, 0.0200000],
         [-4.4978600, 0.8211300, 0.5604100],
         [-4.4978700, -0.8000100, 0.4155600],
         [0.00000000, -0.00000000, -0.00000000]]
    ], requires_grad=True, device=self.device)
    species = torch.tensor([[1, 0, 0, 0, 0], [2, 0, 0, 0, -1]], device=self.device)

    # Reference gradient from the pure-PyTorch AEV computer.
    _, aev = self.aev_computer((species, coordinates))
    aev.backward(torch.ones_like(aev))
    grad_ref = coordinates.grad

    # Re-create an independent leaf tensor so the two backward passes
    # accumulate into separate .grad buffers.
    coordinates = coordinates.clone().detach()
    coordinates.requires_grad_()
    _, cu_aev = self.cuaev_computer((species, coordinates))
    cu_aev.backward(torch.ones_like(cu_aev))
    grad_cuda = coordinates.grad

    self.assertEqual(cu_aev, aev, f'cu_aev: {cu_aev}\n aev: {aev}')
    self.assertEqual(grad_cuda, grad_ref, f'\ncuaev_grad: {grad_cuda}\n aev_grad: {grad_ref}')
def testTripeptideMD(self): def testTripeptideMD(self):
for i in range(100): for i in range(100):
datafile = os.path.join(path, 'test_data/tripeptide-md/{}.dat'.format(i)) datafile = os.path.join(path, 'test_data/tripeptide-md/{}.dat'.format(i))
with open(datafile, 'rb') as f: with open(datafile, 'rb') as f:
coordinates, species, _, _, _, _, _, _ = pickle.load(f) coordinates, species, *_ = pickle.load(f)
coordinates = torch.from_numpy(coordinates).float().unsqueeze(0).to(self.device) coordinates = torch.from_numpy(coordinates).float().unsqueeze(0).to(self.device)
species = torch.from_numpy(species).unsqueeze(0).to(self.device) species = torch.from_numpy(species).unsqueeze(0).to(self.device)
_, aev = self.aev_computer((species, coordinates)) _, aev = self.aev_computer((species, coordinates))
_, cu_aev = self.cuaev_computer((species, coordinates)) _, cu_aev = self.cuaev_computer((species, coordinates))
self.assertEqual(cu_aev, aev) self.assertEqual(cu_aev, aev)
def testTripeptideMDBackward(self):
    """Gradient parity between the CUDA and reference AEV on 100 tripeptide
    MD frames loaded from pickled test data.
    """
    for i in range(100):
        datafile = os.path.join(path, 'test_data/tripeptide-md/{}.dat'.format(i))
        with open(datafile, 'rb') as f:
            coordinates, species, *_ = pickle.load(f)
        species = torch.from_numpy(species).unsqueeze(0).to(self.device)
        coordinates = torch.from_numpy(coordinates).float().unsqueeze(0).to(self.device).requires_grad_(True)

        # Reference gradient from the pure-PyTorch AEV computer.
        _, aev = self.aev_computer((species, coordinates))
        aev.backward(torch.ones_like(aev))
        grad_ref = coordinates.grad

        # Fresh leaf tensor so the CUDA gradient accumulates independently.
        coordinates = coordinates.clone().detach()
        coordinates.requires_grad_()
        _, cu_aev = self.cuaev_computer((species, coordinates))
        cu_aev.backward(torch.ones_like(cu_aev))
        grad_cuda = coordinates.grad

        self.assertEqual(cu_aev, aev)
        self.assertEqual(grad_cuda, grad_ref, atol=5e-5, rtol=5e-5)
def testNIST(self): def testNIST(self):
datafile = os.path.join(path, 'test_data/NIST/all') datafile = os.path.join(path, 'test_data/NIST/all')
with open(datafile, 'rb') as f: with open(datafile, 'rb') as f:
...@@ -94,11 +140,33 @@ class TestCUAEV(TestCase): ...@@ -94,11 +140,33 @@ class TestCUAEV(TestCase):
_, cu_aev = self.cuaev_computer((species, coordinates)) _, cu_aev = self.cuaev_computer((species, coordinates))
self.assertEqual(cu_aev, aev) self.assertEqual(cu_aev, aev)
def testNISTBackward(self):
    """Gradients of the CUDA AEV must match the reference implementation
    on the NIST test set; the gradient comparison uses a looser 5e-5
    tolerance.
    """
    datafile = os.path.join(path, 'test_data/NIST/all')
    with open(datafile, 'rb') as f:
        data = pickle.load(f)
        # Star-unpack the extra fields, consistent with the other tests
        # (testTripeptideMD etc.), so this test keeps working if the
        # pickled record ever grows.
        for coordinates, species, *_ in data:
            coordinates = torch.from_numpy(coordinates).to(torch.float).to(self.device).requires_grad_(True)
            species = torch.from_numpy(species).to(self.device)

            # Reference gradient from the pure-PyTorch AEV computer.
            _, aev = self.aev_computer((species, coordinates))
            aev.backward(torch.ones_like(aev))
            aev_grad = coordinates.grad

            # Fresh leaf tensor so the CUDA gradient accumulates independently.
            coordinates = coordinates.clone().detach()
            coordinates.requires_grad_()
            _, cu_aev = self.cuaev_computer((species, coordinates))
            cu_aev.backward(torch.ones_like(cu_aev))
            cuaev_grad = coordinates.grad

            self.assertEqual(cu_aev, aev)
            self.assertEqual(cuaev_grad, aev_grad, atol=5e-5, rtol=5e-5)
def testVeryDenseMolecule(self): def testVeryDenseMolecule(self):
"""
Test very dense molecule for aev correctness, especially for angular part
"""
for i in range(100): for i in range(100):
datafile = os.path.join(path, 'test_data/tripeptide-md/{}.dat'.format(i)) datafile = os.path.join(path, 'test_data/tripeptide-md/{}.dat'.format(i))
with open(datafile, 'rb') as f: with open(datafile, 'rb') as f:
coordinates, species, _, _, _, _, _, _ = pickle.load(f) coordinates, species, *_ = pickle.load(f)
# change angstrom coordinates to 10 times smaller # change angstrom coordinates to 10 times smaller
coordinates = 0.1 * torch.from_numpy(coordinates).float().unsqueeze(0).to(self.device) coordinates = 0.1 * torch.from_numpy(coordinates).float().unsqueeze(0).to(self.device)
species = torch.from_numpy(species).unsqueeze(0).to(self.device) species = torch.from_numpy(species).unsqueeze(0).to(self.device)
...@@ -106,6 +174,28 @@ class TestCUAEV(TestCase): ...@@ -106,6 +174,28 @@ class TestCUAEV(TestCase):
_, cu_aev = self.cuaev_computer((species, coordinates)) _, cu_aev = self.cuaev_computer((species, coordinates))
self.assertEqual(cu_aev, aev, atol=5e-5, rtol=5e-5) self.assertEqual(cu_aev, aev, atol=5e-5, rtol=5e-5)
def testVeryDenseMoleculeBackward(self):
    """Gradient check on artificially compressed (10x denser) geometries,
    which stresses the angular AEV terms.
    """
    for frame in range(100):
        datafile = os.path.join(path, 'test_data/tripeptide-md/{}.dat'.format(frame))
        with open(datafile, 'rb') as f:
            coordinates, species, *_ = pickle.load(f)

        # Shrink the angstrom coordinates by a factor of 10 to make the
        # molecule very dense.
        coordinates = 0.1 * torch.from_numpy(coordinates).float().unsqueeze(0).to(self.device)
        coordinates.requires_grad_(True)
        species = torch.from_numpy(species).unsqueeze(0).to(self.device)

        # Reference gradient from the pure-PyTorch AEV computer.
        _, aev = self.aev_computer((species, coordinates))
        aev.backward(torch.ones_like(aev))
        grad_ref = coordinates.grad

        # Fresh leaf tensor so the CUDA gradient accumulates independently.
        coordinates = coordinates.clone().detach()
        coordinates.requires_grad_()
        _, cu_aev = self.cuaev_computer((species, coordinates))
        cu_aev.backward(torch.ones_like(cu_aev))
        grad_cuda = coordinates.grad

        self.assertEqual(cu_aev, aev, atol=5e-5, rtol=5e-5)
        self.assertEqual(grad_cuda, grad_ref, atol=5e-4, rtol=5e-4)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
#include <thrust/equal.h> #include <thrust/equal.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <cub/cub.cuh> #include <cub/cub.cuh>
#include <vector>
#include <ATen/Context.h> #include <ATen/Context.h>
#include <THC/THC.h> #include <THC/THC.h>
...@@ -8,17 +9,38 @@ ...@@ -8,17 +9,38 @@
#include <THC/THCThrustAllocator.cuh> #include <THC/THCThrustAllocator.cuh>
#define PI 3.141592653589793 #define PI 3.141592653589793
using torch::Tensor;
using torch::autograd::tensor_list;
// Scalar hyperparameters of the AEV computation, bundled so they can be passed
// by value to kernels and round-tripped through a torch::IValue tuple
// (e.g. to be saved in an autograd context between forward and backward).
template <typename DataT, typename IndexT = int>
struct AEVScalarParams {
  DataT Rcr; // radial cutoff radius
  DataT Rca; // angular cutoff radius
  IndexT radial_sublength;
  IndexT radial_length;
  IndexT angular_sublength;
  IndexT angular_length;
  IndexT num_species;

  AEVScalarParams() = default;

  // Reconstruct from the 7-element tuple produced by the IValue conversion
  // operator below; field order must match exactly.
  AEVScalarParams(const torch::IValue& aev_params_ivalue) {
    c10::intrusive_ptr<c10::ivalue::Tuple> aev_params_tuple_ptr = aev_params_ivalue.toTuple();
    auto aev_params_tuple = aev_params_tuple_ptr->elements();

    Rcr = static_cast<DataT>(aev_params_tuple[0].toDouble());
    Rca = static_cast<DataT>(aev_params_tuple[1].toDouble());
    radial_sublength = static_cast<IndexT>(aev_params_tuple[2].toInt());
    radial_length = static_cast<IndexT>(aev_params_tuple[3].toInt());
    angular_sublength = static_cast<IndexT>(aev_params_tuple[4].toInt());
    angular_length = static_cast<IndexT>(aev_params_tuple[5].toInt());
    num_species = static_cast<IndexT>(aev_params_tuple[6].toInt());
  }

  // Serialize into a tuple IValue (inverse of the constructor above).
  operator torch::IValue() {
    return torch::IValue(std::make_tuple(
        (double)Rcr, (double)Rca, radial_sublength, radial_length, angular_sublength, angular_length, num_species));
  }
};
#define MAX_NSPECIES 10 #define MAX_NSPECIES 10
...@@ -105,6 +127,42 @@ __global__ void pairwiseDistance( ...@@ -105,6 +127,42 @@ __global__ void pairwiseDistance(
} }
} }
// every block compute blocksize RIJ's gradient by column major, to avoid atomicAdd waiting
//
// Backward of the pairwise distance: given d(loss)/d(Rij) in grad_radial_dist
// (one entry per radial pair), scatter d(loss)/d(coordinates) into grad_coord.
// d(Rij)/d(pos_j) is the unit vector (pos_j - pos_i) / Rij and d(Rij)/d(pos_i)
// is its negation, hence the paired atomicAdds with opposite signs.
template <typename DataT, typename IndexT = int>
__global__ void pairwiseDistance_backward(
    torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> pos_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> grad_radial_dist,
    torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> grad_coord,
    const PairDist<DataT>* d_radialRij,
    IndexT nRadialRij) {
  // NOTE: deliberately transposed ("column major") thread->pair mapping, per
  // the header comment, so that pairs handled concurrently hit different
  // atoms and atomicAdd contention is reduced. Relies on the matching
  // launch configuration — confirm against the caller if changing.
  int gidx = threadIdx.x * gridDim.x + blockIdx.x;

  if (gidx >= nRadialRij)
    return;

  PairDist<DataT> d = d_radialRij[gidx];
  DataT Rij = d.Rij;
  int mol_idx = d.midx;
  int i = d.i;
  int j = d.j;

  const DataT delx = pos_t[mol_idx][j][0] - pos_t[mol_idx][i][0];
  const DataT dely = pos_t[mol_idx][j][1] - pos_t[mol_idx][i][1];
  const DataT delz = pos_t[mol_idx][j][2] - pos_t[mol_idx][i][2];

  // Unit vector along i -> j: the gradient of Rij w.r.t. atom j's position.
  DataT grad_dist_coord_x = delx / Rij;
  DataT grad_dist_coord_y = dely / Rij;
  DataT grad_dist_coord_z = delz / Rij;
  DataT grad_radial_dist_item = grad_radial_dist[gidx];

  // Several pairs may touch the same atom, so accumulation must be atomic.
  atomicAdd(&grad_coord[mol_idx][j][0], grad_radial_dist_item * grad_dist_coord_x);
  atomicAdd(&grad_coord[mol_idx][j][1], grad_radial_dist_item * grad_dist_coord_y);
  atomicAdd(&grad_coord[mol_idx][j][2], grad_radial_dist_item * grad_dist_coord_z);
  atomicAdd(&grad_coord[mol_idx][i][0], -grad_radial_dist_item * grad_dist_coord_x);
  atomicAdd(&grad_coord[mol_idx][i][1], -grad_radial_dist_item * grad_dist_coord_y);
  atomicAdd(&grad_coord[mol_idx][i][2], -grad_radial_dist_item * grad_dist_coord_z);
}
template <typename SpeciesT, typename DataT, typename IndexT = int, int TILEX = 8, int TILEY = 4> template <typename SpeciesT, typename DataT, typename IndexT = int, int TILEX = 8, int TILEY = 4>
__global__ void cuAngularAEVs( __global__ void cuAngularAEVs(
torch::PackedTensorAccessor32<SpeciesT, 2, torch::RestrictPtrTraits> species_t, torch::PackedTensorAccessor32<SpeciesT, 2, torch::RestrictPtrTraits> species_t,
...@@ -124,7 +182,8 @@ __global__ void cuAngularAEVs( ...@@ -124,7 +182,8 @@ __global__ void cuAngularAEVs(
int ncentral_atoms) { int ncentral_atoms) {
extern __shared__ DataT smem[]; extern __shared__ DataT smem[];
int threads_per_catom = TILEX * TILEY; constexpr int threads_per_catom = TILEX * TILEY;
static_assert(threads_per_catom == C10_WARP_SIZE);
int gIdx = blockIdx.x * blockDim.x + threadIdx.x; int gIdx = blockIdx.x * blockDim.x + threadIdx.x;
int cIdx = gIdx / threads_per_catom; // central atom id int cIdx = gIdx / threads_per_catom; // central atom id
...@@ -194,6 +253,8 @@ __global__ void cuAngularAEVs( ...@@ -194,6 +253,8 @@ __global__ void cuAngularAEVs(
} }
short2 tile = make_short2(laneIdx % TILEX, laneIdx / TILEX); short2 tile = make_short2(laneIdx % TILEX, laneIdx / TILEX);
// must sync if threads_per_catom != 32 (wrap size) to make sure shared data is ready
// __syncthreads
for (int jj = 0; jj < jnum; jj++) { for (int jj = 0; jj < jnum; jj++) {
const DataT Rij = sdist[jj]; const DataT Rij = sdist[jj];
...@@ -209,7 +270,7 @@ __global__ void cuAngularAEVs( ...@@ -209,7 +270,7 @@ __global__ void cuAngularAEVs(
theta = acos(0.95 * (sdx[jj] * sdx[kk] + sdy[jj] * sdy[kk] + sdz[jj] * sdz[kk]) / (Rij * Rik)); theta = acos(0.95 * (sdx[jj] * sdx[kk] + sdy[jj] * sdy[kk] + sdz[jj] * sdz[kk]) / (Rij * Rik));
} }
for (int srcLane = 0; srcLane < 32 && (kk_start + srcLane) < jnum; ++srcLane) { for (int srcLane = 0; srcLane < C10_WARP_SIZE && (kk_start + srcLane) < jnum; ++srcLane) {
int kk = kk_start + srcLane; int kk = kk_start + srcLane;
DataT theta_ijk = __shfl_sync(0xFFFFFFFF, theta, srcLane); DataT theta_ijk = __shfl_sync(0xFFFFFFFF, theta, srcLane);
...@@ -222,7 +283,6 @@ __global__ void cuAngularAEVs( ...@@ -222,7 +283,6 @@ __global__ void cuAngularAEVs(
DataT fc_ijk = fc_ij * fc_ik; DataT fc_ijk = fc_ij * fc_ik;
IndexT subaev_offset = csubaev_offsets[type_j * num_species + type_k]; IndexT subaev_offset = csubaev_offsets[type_j * num_species + type_k];
IndexT aev_offset = aev_params.radial_length + subaev_offset;
for (int itheta = tile.x; itheta < nShfZ; itheta += TILEX) { for (int itheta = tile.x; itheta < nShfZ; itheta += TILEX) {
DataT ShfZ = ShfZ_t[itheta]; DataT ShfZ = ShfZ_t[itheta];
...@@ -247,6 +307,258 @@ __global__ void cuAngularAEVs( ...@@ -247,6 +307,258 @@ __global__ void cuAngularAEVs(
} }
} }
// Backward kernel of the angular AEV: given grad_output = d(loss)/d(aev),
// accumulate d(loss)/d(coordinates) into grad_coord. One warp
// (TILEX * TILEY == 32 threads, enforced by the static_assert below) handles
// one center atom i; its neighbor list [start_idx, start_idx + jnum) comes
// from d_Rij grouped via d_centralAtom / d_nPairsPerCenterAtom /
// d_centerAtomStartIdx. Neighbor data is staged in dynamic shared memory,
// per-(j,k)-pair gradients are reduced across the warp with
// __shfl_down_sync, and the final per-atom sums are written with atomicAdd.
template <typename SpeciesT, typename DataT, typename IndexT = int, int TILEX = 8, int TILEY = 4>
__global__ void
// __launch_bounds__(32)
cuAngularAEVs_backward(
    torch::PackedTensorAccessor32<SpeciesT, 2, torch::RestrictPtrTraits> species_t,
    torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> pos_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> ShfA_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> ShfZ_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> EtaA_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> Zeta_t,
    torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> grad_output,
    torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> grad_coord,
    const PairDist<DataT>* d_Rij,
    const PairDist<DataT>* d_centralAtom,
    int* d_nPairsPerCenterAtom,
    int* d_centerAtomStartIdx,
    AEVScalarParams<DataT, IndexT> aev_params,
    int maxnbrs_per_atom_aligned,
    int angular_length_aligned,
    int ncentral_atoms) {
  extern __shared__ DataT smem[];

  constexpr int threads_per_catom = TILEX * TILEY;
  static_assert(threads_per_catom == C10_WARP_SIZE);
  int gIdx = blockIdx.x * blockDim.x + threadIdx.x;
  int cIdx = gIdx / threads_per_catom; // central atom id

  if (cIdx >= ncentral_atoms)
    return;

  int groupIdx = threadIdx.x / threads_per_catom;
  int laneIdx = threadIdx.x % threads_per_catom;
  int ncatom_per_tpb = blockDim.x / threads_per_catom; // e.g. 2 catom per block

  // Partition dynamic shared memory into per-group arrays of
  // maxnbrs_per_atom_aligned entries: neighbor displacement vectors
  // (sdx/sdy/sdz), per-neighbor gradient accumulators (sdj{x,y,z}_grad),
  // distances (sdist), cutoff values (sfc) and their derivatives (sfc_grad),
  // and neighbor species (stype).
  DataT* sdx = &smem[groupIdx * maxnbrs_per_atom_aligned];
  int offset = ncatom_per_tpb * maxnbrs_per_atom_aligned;

  DataT* sdy = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];
  offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;

  DataT* sdz = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];
  offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;

  DataT* sdjx_grad = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];
  offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;

  DataT* sdjy_grad = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];
  offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;

  DataT* sdjz_grad = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];
  offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;

  DataT* sdist = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];
  offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;

  DataT* sfc = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];
  offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;

  DataT* sfc_grad = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];
  offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;

  int* stype = (int*)&smem[offset + groupIdx * maxnbrs_per_atom_aligned];

  DataT EtaA = EtaA_t[0];
  DataT Zeta = Zeta_t[0];
  IndexT nShfA = ShfA_t.size(0);
  IndexT nShfZ = ShfZ_t.size(0);
  DataT Rca = aev_params.Rca;
  IndexT num_species = aev_params.num_species;

  PairDist<DataT> d = d_centralAtom[cIdx];
  int start_idx = d_centerAtomStartIdx[cIdx];
  int jnum = d_nPairsPerCenterAtom[cIdx];

  // center atom
  int i = d.i;
  int mol_idx = d.midx;

  DataT xi = pos_t[mol_idx][i][0];
  DataT yi = pos_t[mol_idx][i][1];
  DataT zi = pos_t[mol_idx][i][2];

  // Stage all neighbor data for this center atom in shared memory; also
  // precompute the cutoff function and its radial derivative per neighbor.
  for (int jj = laneIdx; jj < jnum; jj += threads_per_catom) {
    PairDist<DataT> dij = d_Rij[start_idx + jj];
    int j = dij.j;
    DataT Rij = dij.Rij;
    SpeciesT type_j = species_t[mol_idx][j];
    sdx[jj] = pos_t[mol_idx][j][0] - xi;
    sdy[jj] = pos_t[mol_idx][j][1] - yi;
    sdz[jj] = pos_t[mol_idx][j][2] - zi;
    stype[jj] = type_j;
    sdist[jj] = Rij;
    // cutoff
    DataT fc_ij = 0.5 * cos(PI * Rij / Rca) + 0.5;
    DataT fc_ij_grad = -0.5 * (PI / Rca) * sin(PI * Rij / Rca);
    sfc[jj] = fc_ij;
    sfc_grad[jj] = fc_ij_grad;
  }

  // grad init: center-atom gradient is kept in registers, per-neighbor
  // gradients are accumulated in shared memory.
  DataT sdix_grad = 0;
  DataT sdiy_grad = 0;
  DataT sdiz_grad = 0;

  for (int jj = laneIdx; jj < jnum; jj += threads_per_catom) {
    sdjx_grad[jj] = 0;
    sdjy_grad[jj] = 0;
    sdjz_grad[jj] = 0;
  }

  short2 tile = make_short2(laneIdx % TILEX, laneIdx / TILEX);
  const DataT tc = 0.95; // theta constant factor
  // must sync if threads_per_catom != 32 (warp size) to make sure shared data is ready
  // __syncthreads

  // Loop over all neighbor pairs (j, k) with k > j. Each lane owns one k per
  // kk_start chunk, computes theta and its gradients, then the results are
  // broadcast lane-by-lane with __shfl_sync so the whole warp can tile over
  // the (itheta, ishfr) shift grid for that (j, k) pair.
  for (int jj = 0; jj < jnum; jj++) {
    const DataT Rij = sdist[jj];
    SpeciesT type_j = stype[jj];
    DataT fc_ij = sfc[jj];
    DataT grad_fc_ij = sfc_grad[jj];

    for (int kk_start = jj + 1; kk_start < jnum; kk_start += threads_per_catom) {
      int kk = kk_start + laneIdx;
      DataT theta = 0;
      DataT grad_theta_vij_x = 0;
      DataT grad_theta_vij_y = 0;
      DataT grad_theta_vij_z = 0;
      DataT grad_theta_vik_x = 0;
      DataT grad_theta_vik_y = 0;
      DataT grad_theta_vik_z = 0;
      if (kk < jnum) {
        const DataT Rik = sdist[kk];
        DataT vij_vik_dot = sdx[jj] * sdx[kk] + sdy[jj] * sdy[kk] + sdz[jj] * sdz[kk];
        theta = acos(tc * vij_vik_dot / (Rij * Rik));
        // grad
        DataT vij_factor =
            tc / (Rij * Rij * Rij * sqrt(-tc * tc * vij_vik_dot * vij_vik_dot / (Rij * Rij) + Rik * Rik));
        DataT vik_factor = tc /
            (Rik * Rik * Rik *
             sqrt(-tc * tc * vij_vik_dot * vij_vik_dot / (Rik * Rik) + Rij * Rij)); // tricky 80ms improved
        grad_theta_vij_x = vij_factor * (sdx[jj] * vij_vik_dot - sdx[kk] * Rij * Rij);
        grad_theta_vij_y = vij_factor * (sdy[jj] * vij_vik_dot - sdy[kk] * Rij * Rij);
        grad_theta_vij_z = vij_factor * (sdz[jj] * vij_vik_dot - sdz[kk] * Rij * Rij);
        grad_theta_vik_x = vik_factor * (sdx[kk] * vij_vik_dot - sdx[jj] * Rik * Rik);
        grad_theta_vik_y = vik_factor * (sdy[kk] * vij_vik_dot - sdy[jj] * Rik * Rik);
        grad_theta_vik_z = vik_factor * (sdz[kk] * vij_vik_dot - sdz[jj] * Rik * Rik);
      }
      // Broadcast each lane's theta/gradients to the whole warp so every
      // lane can help cover the shift grid for that single (j, k) pair.
      for (int srcLane = 0; srcLane < C10_WARP_SIZE && (kk_start + srcLane) < jnum; ++srcLane) {
        int kk = kk_start + srcLane;
        DataT theta_ijk = __shfl_sync(0xFFFFFFFF, theta, srcLane);
        // TODO necessary?
        DataT grad_theta_vij_x_ = __shfl_sync(0xFFFFFFFF, grad_theta_vij_x, srcLane);
        DataT grad_theta_vij_y_ = __shfl_sync(0xFFFFFFFF, grad_theta_vij_y, srcLane);
        DataT grad_theta_vij_z_ = __shfl_sync(0xFFFFFFFF, grad_theta_vij_z, srcLane);
        DataT grad_theta_vik_x_ = __shfl_sync(0xFFFFFFFF, grad_theta_vik_x, srcLane);
        DataT grad_theta_vik_y_ = __shfl_sync(0xFFFFFFFF, grad_theta_vik_y, srcLane);
        DataT grad_theta_vik_z_ = __shfl_sync(0xFFFFFFFF, grad_theta_vik_z, srcLane);
        const DataT Rik = sdist[kk];
        SpeciesT type_k = stype[kk];
        DataT fc_ik = sfc[kk];
        DataT grad_fc_ik = sfc_grad[kk];

        DataT Rijk = (Rij + Rik) / 2;
        DataT fc_ijk = fc_ij * fc_ik;

        IndexT subaev_offset = csubaev_offsets[type_j * num_species + type_k];

        // Tile the (itheta, ishfr) shift grid across the warp: tile.x lanes
        // step through ShfZ, tile.y lanes through ShfA.
        for (int itheta = tile.x; itheta < nShfZ; itheta += TILEX) {
          DataT ShfZ = ShfZ_t[itheta];

          DataT factor1 = pow((1 + cos(theta_ijk - ShfZ)) / 2, Zeta);
          DataT grad_factor1_theta = 1.0 / 2.0 * Zeta * pow((1 + cos(ShfZ - theta_ijk)) / 2, Zeta - 1) *
              sin(ShfZ - theta_ijk); // tricky 100ms improved

          for (int ishfr = tile.y; ishfr < nShfA; ishfr += TILEY) {
            DataT ShfA = ShfA_t[ishfr];
            DataT factor2 = exp(-EtaA * (Rijk - ShfA) * (Rijk - ShfA));
            DataT grad_factor2_dist = -EtaA * (Rijk - ShfA) * factor2;

            DataT grad_output_item =
                grad_output[mol_idx][i][aev_params.radial_length + subaev_offset + ishfr * nShfZ + itheta];

            // Chain rule: product rule over the three factors
            // (angular factor1, radial factor2, cutoff product fc_ijk),
            // separately for the i->j and i->k displacement vectors.
            DataT grad_vij_x = 2 * grad_output_item *
                (grad_factor1_theta * grad_theta_vij_x_ * factor2 * fc_ijk +
                 factor1 * grad_factor2_dist * sdx[jj] / Rij * fc_ijk +
                 factor1 * factor2 * fc_ik * grad_fc_ij * sdx[jj] / Rij);
            DataT grad_vij_y = 2 * grad_output_item *
                (grad_factor1_theta * grad_theta_vij_y_ * factor2 * fc_ijk +
                 factor1 * grad_factor2_dist * sdy[jj] / Rij * fc_ijk +
                 factor1 * factor2 * fc_ik * grad_fc_ij * sdy[jj] / Rij);
            DataT grad_vij_z = 2 * grad_output_item *
                (grad_factor1_theta * grad_theta_vij_z_ * factor2 * fc_ijk +
                 factor1 * grad_factor2_dist * sdz[jj] / Rij * fc_ijk +
                 factor1 * factor2 * fc_ik * grad_fc_ij * sdz[jj] / Rij);
            DataT grad_vik_x = 2 * grad_output_item *
                (grad_factor1_theta * grad_theta_vik_x_ * factor2 * fc_ijk +
                 factor1 * grad_factor2_dist * sdx[kk] / Rik * fc_ijk +
                 factor1 * factor2 * fc_ij * grad_fc_ik * sdx[kk] / Rik);
            DataT grad_vik_y = 2 * grad_output_item *
                (grad_factor1_theta * grad_theta_vik_y_ * factor2 * fc_ijk +
                 factor1 * grad_factor2_dist * sdy[kk] / Rik * fc_ijk +
                 factor1 * factor2 * fc_ij * grad_fc_ik * sdy[kk] / Rik);
            DataT grad_vik_z = 2 * grad_output_item *
                (grad_factor1_theta * grad_theta_vik_z_ * factor2 * fc_ijk +
                 factor1 * grad_factor2_dist * sdz[kk] / Rik * fc_ijk +
                 factor1 * factor2 * fc_ij * grad_fc_ik * sdz[kk] / Rik);

            // The center atom moves opposite to both displacement vectors.
            sdix_grad += (-grad_vij_x - grad_vik_x);
            sdiy_grad += (-grad_vij_y - grad_vik_y);
            sdiz_grad += (-grad_vij_z - grad_vik_z);

            // Warp tree-reduction so only lane 0 touches shared memory.
            for (int offset = 16; offset > 0; offset /= 2) {
              grad_vij_x += __shfl_down_sync(0xFFFFFFFF, grad_vij_x, offset);
              grad_vij_y += __shfl_down_sync(0xFFFFFFFF, grad_vij_y, offset);
              grad_vij_z += __shfl_down_sync(0xFFFFFFFF, grad_vij_z, offset);
              grad_vik_x += __shfl_down_sync(0xFFFFFFFF, grad_vik_x, offset);
              grad_vik_y += __shfl_down_sync(0xFFFFFFFF, grad_vik_y, offset);
              grad_vik_z += __shfl_down_sync(0xFFFFFFFF, grad_vik_z, offset);
            }
            if (laneIdx == 0) {
              sdjx_grad[jj] += grad_vij_x;
              sdjy_grad[jj] += grad_vij_y;
              sdjz_grad[jj] += grad_vij_z;

              sdjx_grad[kk] += grad_vik_x;
              sdjy_grad[kk] += grad_vik_y;
              sdjz_grad[kk] += grad_vik_z;
            }
          }
        }
      }
    }
  }

  // Flush the accumulated gradients to global memory. Atomics are required
  // because an atom can be the neighbor of multiple center atoms.
  int atomi_idx = i;
  atomicAdd(&grad_coord[mol_idx][atomi_idx][0], sdix_grad);
  atomicAdd(&grad_coord[mol_idx][atomi_idx][1], sdiy_grad);
  atomicAdd(&grad_coord[mol_idx][atomi_idx][2], sdiz_grad);

  for (int jj = laneIdx; jj < jnum; jj += threads_per_catom) {
    int atomj_idx = d_Rij[start_idx + jj].j;
    atomicAdd(&grad_coord[mol_idx][atomj_idx][0], sdjx_grad[jj]);
    atomicAdd(&grad_coord[mol_idx][atomj_idx][1], sdjy_grad[jj]);
    atomicAdd(&grad_coord[mol_idx][atomj_idx][2], sdjz_grad[jj]);
  }
}
template <typename SpeciesT, typename DataT, int THREADS_PER_RIJ> template <typename SpeciesT, typename DataT, int THREADS_PER_RIJ>
__global__ void cuRadialAEVs( __global__ void cuRadialAEVs(
torch::PackedTensorAccessor32<SpeciesT, 2, torch::RestrictPtrTraits> species_t, torch::PackedTensorAccessor32<SpeciesT, 2, torch::RestrictPtrTraits> species_t,
...@@ -273,7 +585,6 @@ __global__ void cuRadialAEVs( ...@@ -273,7 +585,6 @@ __global__ void cuRadialAEVs(
int i = d.i; int i = d.i;
int j = d.j; int j = d.j;
SpeciesT type_i = species_t[mol_idx][i];
SpeciesT type_j = species_t[mol_idx][j]; SpeciesT type_j = species_t[mol_idx][j];
DataT fc = 0.5 * cos(PI * Rij / aev_params.Rcr) + 0.5; DataT fc = 0.5 * cos(PI * Rij / aev_params.Rcr) + 0.5;
...@@ -287,6 +598,52 @@ __global__ void cuRadialAEVs( ...@@ -287,6 +598,52 @@ __global__ void cuRadialAEVs(
} }
} }
// every <THREADS_PER_RIJ> threads take care of 1 RIJ, and iterate <nShfR / THREADS_PER_RIJ> times
//
// Backward of the radial AEV w.r.t. the pairwise distances: for each radial
// pair (i, j) and each radial shift ShfR, apply the chain rule through
// GmR * fc and accumulate d(loss)/d(Rij) into grad_radial_dist[idx]
// (pairwiseDistance_backward then turns that into coordinate gradients).
template <typename SpeciesT, typename DataT, int THREADS_PER_RIJ>
__global__ void cuRadialAEVs_backward(
    torch::PackedTensorAccessor32<SpeciesT, 2, torch::RestrictPtrTraits> species_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> ShfR_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> EtaR_t,
    torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> grad_output,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> grad_radial_dist,
    const PairDist<DataT>* d_Rij,
    AEVScalarParams<DataT, int> aev_params,
    int nRadialRij) {
  int gidx = blockIdx.x * blockDim.x + threadIdx.x;
  int idx = gidx / THREADS_PER_RIJ;

  int nShfR = ShfR_t.size(0);
  DataT EtaR = EtaR_t[0];

  if (idx >= nRadialRij)
    return;

  int laneIdx = threadIdx.x % THREADS_PER_RIJ;

  PairDist<DataT> d = d_Rij[idx];
  DataT Rij = d.Rij;
  int mol_idx = d.midx;
  int i = d.i;
  int j = d.j;

  SpeciesT type_j = species_t[mol_idx][j];

  // Cosine cutoff and its derivative w.r.t. Rij (shared by all shifts).
  DataT fc = 0.5 * cos(PI * Rij / aev_params.Rcr) + 0.5;
  DataT fc_grad = -0.5 * (PI / aev_params.Rcr) * sin(PI * Rij / aev_params.Rcr);

  // Accumulate this thread's partial sum over its shifts in a register and
  // issue a single atomicAdd at the end instead of one per shift; atomicAdd
  // gives no ordering guarantee anyway, so the result is equivalent while
  // contention on grad_radial_dist[idx] drops by a factor of nShfR/THREADS_PER_RIJ.
  DataT grad_dist_acc = 0;
  for (int ishfr = laneIdx; ishfr < nShfR; ishfr += THREADS_PER_RIJ) {
    DataT ShfR = ShfR_t[ishfr];

    // GmR is the radial term of the forward pass; GmR_grad = d(GmR)/d(Rij).
    DataT GmR = 0.25 * exp(-EtaR * (Rij - ShfR) * (Rij - ShfR));
    DataT GmR_grad = -EtaR * (-2 * ShfR + 2 * Rij) * GmR;

    DataT grad_output_item = grad_output[mol_idx][i][type_j * aev_params.radial_sublength + ishfr];
    grad_dist_acc += grad_output_item * (GmR_grad * fc + GmR * fc_grad);
  }
  atomicAdd(&grad_radial_dist[idx], grad_dist_acc);
}
template <typename DataT> template <typename DataT>
void cubScan(const DataT* d_in, DataT* d_out, int num_items, cudaStream_t stream) { void cubScan(const DataT* d_in, DataT* d_out, int num_items, cudaStream_t stream) {
auto& allocator = *c10::cuda::CUDACachingAllocator::get(); auto& allocator = *c10::cuda::CUDACachingAllocator::get();
...@@ -406,19 +763,36 @@ void initConsts(AEVScalarParams<float>& aev_params, cudaStream_t stream) { ...@@ -406,19 +763,36 @@ void initConsts(AEVScalarParams<float>& aev_params, cudaStream_t stream) {
delete[] subaev_offsets; delete[] subaev_offsets;
} }
// Bundle of everything cuaev_forward produces: the AEV output plus the
// intermediate buffers needed by the backward kernels. The raw device
// buffers are held as (kUInt8) Tensors so their lifetime is tied to this
// struct rather than to a transient allocator handle.
struct Result {
  Tensor aev_t; // computed AEVs: (n_molecules, max_natoms_per_mol, aev_length)
  AEVScalarParams<float> aev_params;
  Tensor tensor_Rij;        // PairDist buffer for all atom pairs
  Tensor tensor_radialRij;  // pairs with Rij <= Rcr (radial AEV input)
  Tensor tensor_angularRij; // pairs with Rij <= Rca (angular AEV input)
  int total_natom_pairs;    // n_molecules * max_natoms_per_mol^2
  int nRadialRij;           // number of valid entries in tensor_radialRij
  int nAngularRij;          // number of valid entries in tensor_angularRij
  Tensor tensor_centralAtom;           // one PairDist per center atom (group keys)
  Tensor tensor_numPairsPerCenterAtom; // neighbor count per center atom
  Tensor tensor_centerAtomStartIdx;    // prefix-sum offsets into tensor_angularRij
  int maxnbrs_per_atom_aligned; // used to size per-atom shared-memory arrays
  int angular_length_aligned;
  int ncenter_atoms;
};
// NOTE: assumes size of EtaA_t = Zeta_t = EtaR_t = 1 // NOTE: assumes size of EtaA_t = Zeta_t = EtaR_t = 1
template <typename ScalarRealT = float> template <typename ScalarRealT = float>
torch::Tensor cuComputeAEV( Result cuaev_forward(
torch::Tensor coordinates_t, const Tensor& coordinates_t,
torch::Tensor species_t, const Tensor& species_t,
double Rcr_, double Rcr_,
double Rca_, double Rca_,
torch::Tensor EtaR_t, const Tensor& EtaR_t,
torch::Tensor ShfR_t, const Tensor& ShfR_t,
torch::Tensor EtaA_t, const Tensor& EtaA_t,
torch::Tensor Zeta_t, const Tensor& Zeta_t,
torch::Tensor ShfA_t, const Tensor& ShfA_t,
torch::Tensor ShfZ_t, const Tensor& ShfZ_t,
int64_t num_species_) { int64_t num_species_) {
TORCH_CHECK( TORCH_CHECK(
(species_t.dtype() == torch::kInt32) && (coordinates_t.dtype() == torch::kFloat32), "Unsupported input type"); (species_t.dtype() == torch::kInt32) && (coordinates_t.dtype() == torch::kFloat32), "Unsupported input type");
...@@ -450,7 +824,7 @@ torch::Tensor cuComputeAEV( ...@@ -450,7 +824,7 @@ torch::Tensor cuComputeAEV(
auto aev_t = torch::zeros({n_molecules, max_natoms_per_mol, aev_length}, coordinates_t.options()); auto aev_t = torch::zeros({n_molecules, max_natoms_per_mol, aev_length}, coordinates_t.options());
if (species_t.numel() == 0) { if (species_t.numel() == 0) {
return aev_t; return {aev_t, aev_params, Tensor(), Tensor(), Tensor(), 0, 0, 0};
} }
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream();
...@@ -463,8 +837,9 @@ torch::Tensor cuComputeAEV( ...@@ -463,8 +837,9 @@ torch::Tensor cuComputeAEV(
// buffer to store all the pairwise distance (Rij) // buffer to store all the pairwise distance (Rij)
auto total_natom_pairs = n_molecules * max_natoms_per_mol * max_natoms_per_mol; auto total_natom_pairs = n_molecules * max_natoms_per_mol * max_natoms_per_mol;
auto buffer_Rij = allocator.allocate(sizeof(PairDist<float>) * total_natom_pairs); auto d_options = torch::dtype(torch::kUInt8).device(coordinates_t.device());
PairDist<float>* d_Rij = (PairDist<float>*)buffer_Rij.get(); Tensor tensor_Rij = torch::empty(sizeof(PairDist<float>) * total_natom_pairs, d_options);
PairDist<float>* d_Rij = (PairDist<float>*)tensor_Rij.data_ptr();
// init all Rij to inf // init all Rij to inf
PairDist<float> init; PairDist<float> init;
...@@ -473,8 +848,8 @@ torch::Tensor cuComputeAEV( ...@@ -473,8 +848,8 @@ torch::Tensor cuComputeAEV(
// buffer to store all the pairwise distance that is needed for Radial AEV // buffer to store all the pairwise distance that is needed for Radial AEV
// computation // computation
auto buffer_radialRij = allocator.allocate(sizeof(PairDist<float>) * total_natom_pairs); Tensor tensor_radialRij = torch::empty(sizeof(PairDist<float>) * total_natom_pairs, d_options);
PairDist<float>* d_radialRij = (PairDist<float>*)buffer_radialRij.get(); PairDist<float>* d_radialRij = (PairDist<float>*)tensor_radialRij.data_ptr();
auto buffer_count = allocator.allocate(sizeof(int)); auto buffer_count = allocator.allocate(sizeof(int));
int* d_count_out = (int*)buffer_count.get(); int* d_count_out = (int*)buffer_count.get();
...@@ -483,14 +858,14 @@ torch::Tensor cuComputeAEV( ...@@ -483,14 +858,14 @@ torch::Tensor cuComputeAEV(
dim3 block(8, 8, 1); dim3 block(8, 8, 1);
// Compute pairwise distance (Rij) for all atom pairs in a molecule // Compute pairwise distance (Rij) for all atom pairs in a molecule
// maximum 4096 atoms, which needs 49152 byte (48 kb) of shared memory
pairwiseDistance<<<n_molecules, block, sizeof(float) * max_natoms_per_mol * 3, stream>>>( pairwiseDistance<<<n_molecules, block, sizeof(float) * max_natoms_per_mol * 3, stream>>>(
species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(), species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(), coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
d_Rij, d_Rij,
max_natoms_per_mol); max_natoms_per_mol);
// Extract Rijs that is needed for RadialAEV comptuation i.e. all the Rij <= // Extract Rijs that is needed for RadialAEV comptuation i.e. all the Rij <= Rcr
// Rcr
int nRadialRij = cubDeviceSelect( int nRadialRij = cubDeviceSelect(
d_Rij, d_Rij,
d_radialRij, d_radialRij,
...@@ -511,7 +886,8 @@ torch::Tensor cuComputeAEV( ...@@ -511,7 +886,8 @@ torch::Tensor cuComputeAEV(
// reuse buffer allocated for all Rij // reuse buffer allocated for all Rij
// d_angularRij will store all the Rij required in Angular AEV computation // d_angularRij will store all the Rij required in Angular AEV computation
PairDist<float>* d_angularRij = d_Rij; Tensor tensor_angularRij = torch::empty(sizeof(PairDist<float>) * nRadialRij, d_options);
PairDist<float>* d_angularRij = (PairDist<float>*)tensor_angularRij.data_ptr();
// Extract Rijs that is needed for AngularAEV comptuation i.e. all the Rij // Extract Rijs that is needed for AngularAEV comptuation i.e. all the Rij
// <= Rca // <= Rca
...@@ -523,24 +899,24 @@ torch::Tensor cuComputeAEV( ...@@ -523,24 +899,24 @@ torch::Tensor cuComputeAEV(
[=] __device__(const PairDist<float> d) { return d.Rij <= Rca; }, [=] __device__(const PairDist<float> d) { return d.Rij <= Rca; },
stream); stream);
auto buffer_centralAtom = allocator.allocate(sizeof(PairDist<float>) * nAngularRij); Tensor tensor_centralAtom = torch::empty(sizeof(PairDist<float>) * nAngularRij, d_options);
PairDist<float>* d_centralAtom = (PairDist<float>*)buffer_centralAtom.get(); PairDist<float>* d_centralAtom = (PairDist<float>*)tensor_centralAtom.data_ptr();
auto buffer_numPairsPerCenterAtom = allocator.allocate(sizeof(int) * nAngularRij); Tensor tensor_numPairsPerCenterAtom = torch::empty(sizeof(int) * nAngularRij, d_options);
int* d_numPairsPerCenterAtom = (int*)buffer_numPairsPerCenterAtom.get(); int* d_numPairsPerCenterAtom = (int*)tensor_numPairsPerCenterAtom.data_ptr();
// group by center atom // group by center atom
int ncenter_atoms = cubEncode(d_angularRij, d_centralAtom, d_numPairsPerCenterAtom, nAngularRij, d_count_out, stream); int ncenter_atoms = cubEncode(d_angularRij, d_centralAtom, d_numPairsPerCenterAtom, nAngularRij, d_count_out, stream);
auto buffer_centerAtomStartIdx = allocator.allocate(sizeof(int) * ncenter_atoms); Tensor tensor_centerAtomStartIdx = torch::empty(sizeof(int) * ncenter_atoms, d_options);
int* d_centerAtomStartIdx = (int*)buffer_centerAtomStartIdx.get(); int* d_centerAtomStartIdx = (int*)tensor_centerAtomStartIdx.data_ptr();
cubScan(d_numPairsPerCenterAtom, d_centerAtomStartIdx, ncenter_atoms, stream); cubScan(d_numPairsPerCenterAtom, d_centerAtomStartIdx, ncenter_atoms, stream);
{ {
const int nthreads_per_catom = 32; const int nthreads_per_catom = 32;
const int nblocks_angAEV = (ncenter_atoms * nthreads_per_catom + block_size - 1) / block_size; const int nblocks_angAEV = (ncenter_atoms * nthreads_per_catom + block_size - 1) / block_size;
auto smem_size = [&aev_params](int max_nbrs, int ncatom_per_tpb) { auto smem_size = [&aev_params](int max_nbrs, int ncatom_per_tpb) {
int sm_aev = sizeof(float) * align<4>(aev_params.angular_length); int sm_aev = sizeof(float) * align<4>(aev_params.angular_length); // (angular_length / 4 + 1) * 4
int sxyz = sizeof(float) * max_nbrs * 3; int sxyz = sizeof(float) * max_nbrs * 3;
int sRij = sizeof(float) * max_nbrs; int sRij = sizeof(float) * max_nbrs;
int sfc = sizeof(float) * max_nbrs; int sfc = sizeof(float) * max_nbrs;
...@@ -550,14 +926,11 @@ torch::Tensor cuComputeAEV( ...@@ -550,14 +926,11 @@ torch::Tensor cuComputeAEV(
}; };
int maxNbrsPerCenterAtom = cubMax(d_numPairsPerCenterAtom, ncenter_atoms, d_count_out, stream); int maxNbrsPerCenterAtom = cubMax(d_numPairsPerCenterAtom, ncenter_atoms, d_count_out, stream);
int maxnbrs_per_atom_aligned = align<4>(maxNbrsPerCenterAtom); int maxnbrs_per_atom_aligned = align<4>(maxNbrsPerCenterAtom);
int smem_size_aligned = smem_size(maxnbrs_per_atom_aligned, block_size / nthreads_per_catom);
int angular_length_aligned = align<4>(aev_params.angular_length);
cuAngularAEVs<<< cuAngularAEVs<<<nblocks_angAEV, block_size, smem_size_aligned, stream>>>(
nblocks_angAEV,
block_size,
smem_size(maxnbrs_per_atom_aligned, block_size / nthreads_per_catom),
stream>>>(
species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(), species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(), coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
ShfA_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(), ShfA_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
...@@ -571,14 +944,227 @@ torch::Tensor cuComputeAEV( ...@@ -571,14 +944,227 @@ torch::Tensor cuComputeAEV(
d_centerAtomStartIdx, d_centerAtomStartIdx,
aev_params, aev_params,
maxnbrs_per_atom_aligned, maxnbrs_per_atom_aligned,
align<4>(aev_params.angular_length), angular_length_aligned,
ncenter_atoms); ncenter_atoms);
return {aev_t,
aev_params,
tensor_Rij,
tensor_radialRij,
tensor_angularRij,
total_natom_pairs,
nRadialRij,
nAngularRij,
tensor_centralAtom,
tensor_numPairsPerCenterAtom,
tensor_centerAtomStartIdx,
maxnbrs_per_atom_aligned,
angular_length_aligned,
ncenter_atoms};
} }
return aev_t; }
// Backward pass of the CUDA AEV computation: given dL/dAEV (grad_output),
// accumulate and return dL/dcoordinates.
//
// All tensor_* / n* / *_aligned arguments are intermediates saved by the
// forward pass (pair lists filtered by the radial/angular cutoffs, the
// grouping of angular pairs by center atom, and the matching counts); they
// are reused here so nothing has to be recomputed.
//
// @param grad_output  gradient w.r.t. the AEV output, shape [mols, atoms, aev_length]
// @return             gradient w.r.t. coordinates_t, same shape as coordinates_t
//
// NOTE(review): total_natom_pairs and tensor_Rij are currently unused here;
// they are kept in the signature for interface stability with the caller.
Tensor cuaev_backward(
    const Tensor& grad_output,
    const Tensor& coordinates_t,
    const Tensor& species_t,
    const AEVScalarParams<float>& aev_params,
    const Tensor& EtaR_t,
    const Tensor& ShfR_t,
    const Tensor& EtaA_t,
    const Tensor& Zeta_t,
    const Tensor& ShfA_t,
    const Tensor& ShfZ_t,
    const Tensor& tensor_Rij,
    int total_natom_pairs,
    const Tensor& tensor_radialRij,
    int nRadialRij,
    const Tensor& tensor_angularRij,
    int nAngularRij,
    const Tensor& tensor_centralAtom,
    const Tensor& tensor_numPairsPerCenterAtom,
    const Tensor& tensor_centerAtomStartIdx,
    int maxnbrs_per_atom_aligned,
    int angular_length_aligned,
    int ncenter_atoms) {
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // Kernels accumulate into grad_coord, so it must start zeroed; the result
  // tensor itself must not require grad.
  auto grad_coord = torch::zeros(coordinates_t.sizes(), coordinates_t.options().requires_grad(false));

  // Reinterpret the saved forward intermediates as raw device pointers.
  PairDist<float>* d_radialRij = (PairDist<float>*)tensor_radialRij.data_ptr();
  PairDist<float>* d_angularRij = (PairDist<float>*)tensor_angularRij.data_ptr();
  PairDist<float>* d_centralAtom = (PairDist<float>*)tensor_centralAtom.data_ptr();
  int* d_numPairsPerCenterAtom = (int*)tensor_numPairsPerCenterAtom.data_ptr();
  int* d_centerAtomStartIdx = (int*)tensor_centerAtomStartIdx.data_ptr();

  // Stage 1: radial AEV gradient w.r.t. each pair distance Rij
  // (8 threads cooperate per radial pair, matching the template argument).
  Tensor grad_radial_dist = torch::zeros(nRadialRij, coordinates_t.options().requires_grad(false));
  int block_size = 64;
  int nblocks = (nRadialRij * 8 + block_size - 1) / block_size;
  cuRadialAEVs_backward<int, float, 8><<<nblocks, block_size, 0, stream>>>(
      species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
      ShfR_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
      EtaR_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
      grad_output.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
      grad_radial_dist.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
      d_radialRij,
      aev_params,
      nRadialRij);

  // Stage 2: chain dL/dRij through Rij -> (x_i, x_j) into grad_coord.
  // For best result, block_size should match average molecule size (no padding) to avoid atomicAdd
  nblocks = (nRadialRij + block_size - 1) / block_size;
  pairwiseDistance_backward<<<nblocks, block_size, 0, stream>>>(
      coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
      grad_radial_dist.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
      grad_coord.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
      d_radialRij,
      nRadialRij);

  // Shared-memory budget per thread block for the angular backward kernel:
  // per-neighbor coordinates, per-neighbor coordinate-gradient scratch,
  // distances, cutoff values and their gradients, and neighbor indices.
  auto smem_size = [](int max_nbrs, int ncatom_per_tpb) {
    int sxyz = sizeof(float) * max_nbrs * 3;
    int sj_xyz_grad = sizeof(float) * max_nbrs * 3;
    int sRij = sizeof(float) * max_nbrs;
    int sfc = sizeof(float) * max_nbrs;
    int sfc_grad = sizeof(float) * max_nbrs;
    int sj = sizeof(int) * max_nbrs;
    return (sxyz + sj_xyz_grad + sRij + sfc + sfc_grad + sj) * ncatom_per_tpb;
  };

  // Stage 3: angular AEV gradient, 32 threads cooperating per center atom.
  block_size = 32;
  const int nthreads_per_catom = 32;
  const int nblocks_angAEV = (ncenter_atoms * nthreads_per_catom + block_size - 1) / block_size;
  int smem_size_aligned = smem_size(maxnbrs_per_atom_aligned, block_size / nthreads_per_catom);
  cuAngularAEVs_backward<<<nblocks_angAEV, block_size, smem_size_aligned, stream>>>(
      species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
      coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
      ShfA_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
      ShfZ_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
      EtaA_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
      Zeta_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
      grad_output.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
      grad_coord.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
      d_angularRij,
      d_centralAtom,
      d_numPairsPerCenterAtom,
      d_centerAtomStartIdx,
      aev_params,
      maxnbrs_per_atom_aligned,
      angular_length_aligned,
      ncenter_atoms);
  return grad_coord;
}
// Shared argument list for every cuaev operator entry point (plain CUDA and
// the autograd wrapper), so their signatures cannot drift apart.
#define AEV_INPUT \
  const Tensor &coordinates_t, const Tensor &species_t, double Rcr_, double Rca_, const Tensor &EtaR_t, \
  const Tensor &ShfR_t, const Tensor &EtaA_t, const Tensor &Zeta_t, const Tensor &ShfA_t, const Tensor &ShfZ_t, \
  int64_t num_species_
// Plain CUDA entry point: compute the AEV and return only the AEV tensor,
// discarding the intermediates that the autograd path would save.
Tensor cuaev_cuda(AEV_INPUT) {
  return cuaev_forward<float>(
             coordinates_t, species_t, Rcr_, Rca_, EtaR_t, ShfR_t, EtaA_t, Zeta_t, ShfA_t, ShfZ_t, num_species_)
      .aev_t;
}
// Custom autograd function for the cuaev operator.
//
// forward() runs the CUDA forward pass and stashes every intermediate buffer
// the backward kernels need (filtered pair lists, per-center-atom grouping,
// and the associated counts), so backward() launches the gradient kernels
// directly without recomputing anything.
class CuaevAutograd : public torch::autograd::Function<CuaevAutograd> {
 public:
  static Tensor forward(torch::autograd::AutogradContext* ctx, AEV_INPUT) {
    // NOTE(review): guard presumably keeps the inner call out of autograd
    // dispatch while this Function records its own graph node — confirm.
    at::AutoNonVariableTypeMode g;
    Result res = cuaev_forward<float>(
        coordinates_t, species_t, Rcr_, Rca_, EtaR_t, ShfR_t, EtaA_t, Zeta_t, ShfA_t, ShfZ_t, num_species_);
    // Only pay the cost of saving intermediates when a coordinate gradient
    // will actually be requested.
    if (coordinates_t.requires_grad()) {
      // Saved positionally; backward() below unpacks by the same indices.
      ctx->save_for_backward({coordinates_t,
                              species_t,
                              res.tensor_Rij,
                              res.tensor_radialRij,
                              res.tensor_angularRij,
                              EtaR_t,
                              ShfR_t,
                              EtaA_t,
                              Zeta_t,
                              ShfA_t,
                              ShfZ_t,
                              res.tensor_centralAtom,
                              res.tensor_numPairsPerCenterAtom,
                              res.tensor_centerAtomStartIdx});
      ctx->saved_data["aev_params"] = res.aev_params;
      // Scalar counts travel as an int list; order must match backward().
      ctx->saved_data["int_list"] = c10::List<int64_t>{res.total_natom_pairs,
                                                       res.nRadialRij,
                                                       res.nAngularRij,
                                                       res.maxnbrs_per_atom_aligned,
                                                       res.angular_length_aligned,
                                                       res.ncenter_atoms};
    }
    return res.aev_t;
  }
  static tensor_list backward(torch::autograd::AutogradContext* ctx, tensor_list grad_outputs) {
    // Unpack in the exact positional order used by save_for_backward() above.
    auto saved = ctx->get_saved_variables();
    auto coordinates_t = saved[0], species_t = saved[1];
    auto tensor_Rij = saved[2], tensor_radialRij = saved[3], tensor_angularRij = saved[4];
    auto EtaR_t = saved[5], ShfR_t = saved[6], EtaA_t = saved[7], Zeta_t = saved[8], ShfA_t = saved[9],
         ShfZ_t = saved[10];
    auto tensor_centralAtom = saved[11], tensor_numPairsPerCenterAtom = saved[12],
         tensor_centerAtomStartIdx = saved[13];
    AEVScalarParams<float> aev_params(ctx->saved_data["aev_params"]);
    c10::List<int64_t> int_list = ctx->saved_data["int_list"].toIntList();
    int total_natom_pairs = int_list[0], nRadialRij = int_list[1], nAngularRij = int_list[2];
    int maxnbrs_per_atom_aligned = int_list[3], angular_length_aligned = int_list[4];
    int ncenter_atoms = int_list[5];
    Tensor grad_coord = cuaev_backward(
        grad_outputs[0],
        coordinates_t,
        species_t,
        aev_params,
        EtaR_t,
        ShfR_t,
        EtaA_t,
        Zeta_t,
        ShfA_t,
        ShfZ_t,
        tensor_Rij,
        total_natom_pairs,
        tensor_radialRij,
        nRadialRij,
        tensor_angularRij,
        nAngularRij,
        tensor_centralAtom,
        tensor_numPairsPerCenterAtom,
        tensor_centerAtomStartIdx,
        maxnbrs_per_atom_aligned,
        angular_length_aligned,
        ncenter_atoms);
    // One gradient slot per forward() input (11 total); only coordinates_t is
    // differentiable, every other input gets an undefined Tensor.
    return {
        grad_coord, Tensor(), Tensor(), Tensor(), Tensor(), Tensor(), Tensor(), Tensor(), Tensor(), Tensor(), Tensor()};
  }
};
// Autograd-aware entry point: routes through CuaevAutograd so backward()
// can later produce coordinate gradients.
Tensor cuaev_autograd(AEV_INPUT) {
  return CuaevAutograd::apply(
      coordinates_t, species_t, Rcr_, Rca_, EtaR_t, ShfR_t, EtaA_t, Zeta_t, ShfA_t, ShfZ_t, num_species_);
}
// Operator registration: the schema is defined once, then the CUDA kernel
// and the Autograd wrapper are bound to their respective dispatch keys.
TORCH_LIBRARY(cuaev, m) {
  m.def("cuComputeAEV", cuaev_cuda);
}

TORCH_LIBRARY_IMPL(cuaev, CUDA, m) {
  m.impl("cuComputeAEV", cuaev_cuda);
}

TORCH_LIBRARY_IMPL(cuaev, Autograd, m) {
  m.impl("cuComputeAEV", cuaev_autograd);
}

// Empty pybind module: the extension is loaded for its TORCH_LIBRARY
// registrations only; no Python-level symbols are exported directly.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment