Unverified Commit d2d63056 authored by Richard Xue, committed by GitHub

CRLF to LF (#553)

* line-limit 120

* CRLF to LF
parent a6d819ed
@@ -35,7 +35,7 @@ BreakBeforeTernaryOperators: true
 BreakConstructorInitializersBeforeComma: false
 BreakAfterJavaFieldAnnotations: false
 BreakStringLiterals: false
-ColumnLimit: 100
+ColumnLimit: 120
 CommentPragmas: '^ IWYU pragma:'
 CompactNamespaces: false
 ConstructorInitializerAllOnOneLineOrOnePerLine: true
...
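The ColumnLimit bump from 100 to 120 is what drives the re-wrapping of long lines below; the rest of the diff is the cuaev CUDA source itself, reproduced in its post-change form (LF endings, 120-column wrapping). As an aside, none of the following snippet is part of the commit: one common way to keep a CRLF-to-LF normalization from regressing is to pin line endings in .gitattributes, for example

* text=auto eol=lf

after which recent Git versions can re-check existing files in with "git add --renormalize .", while the re-wrapping itself comes from re-running clang-format against the updated configuration.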
#include <thrust/equal.h>
#include <torch/extension.h>
#include <cub/cub.cuh>
#include <ATen/Context.h>
#include <THC/THC.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <THC/THCThrustAllocator.cuh>

#define PI 3.141592653589793
template <typename DataT, typename IndexT = int>
struct AEVScalarParams {
  DataT Rcr;
  DataT Rca;
  IndexT radial_sublength;
  IndexT radial_length;
  IndexT angular_sublength;
  IndexT angular_length;
  IndexT num_species;
};
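The sublength/length fields are filled in by cuComputeAEV further down. As a worked illustration with hypothetical sizes (none of these numbers come from the commit: num_species = 4, 16 ShfR values, 4 ShfA values, 8 ShfZ values, single-element EtaR/EtaA/Zeta tensors):

// radial_sublength  = |EtaR| * |ShfR|                                         = 1 * 16        = 16
// radial_length     = radial_sublength * num_species                          = 16 * 4        = 64
// angular_sublength = |EtaA| * |Zeta| * |ShfA| * |ShfZ|                       = 1 * 1 * 4 * 8 = 32
// angular_length    = angular_sublength * num_species * (num_species + 1) / 2 = 32 * 10       = 320
// total AEV length  = radial_length + angular_length                          = 64 + 320      = 384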
#define MAX_NSPECIES 10
__constant__ int csubaev_offsets[MAX_NSPECIES * MAX_NSPECIES];

template <typename DataT>
struct PairDist {
  DataT Rij;
  int midx;
  short i;
  short j;
};
// used to group Rijs by atom id
template <typename DataT>
__host__ __device__ bool operator==(const PairDist<DataT>& lhs, const PairDist<DataT>& rhs) {
  return lhs.midx == rhs.midx && lhs.i == rhs.i;
}
/// Alignment of memory. Must be a power of two
/// \tparam boundary Boundary to align to (NOTE: must be power of 2)
/// \param value Input value that is to be aligned
/// \return Value aligned to boundary
template <int32_t boundary>
__host__ __device__ __forceinline__ int align(const int& value) {
  static_assert((boundary & (boundary - 1)) == 0, "Boundary for align must be power of 2");
  return (value + boundary) & ~(boundary - 1);
}
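A quick worked illustration of the rounding behavior (hypothetical host-side checks, not part of the source): because the expression adds the full boundary before masking, an input that is already aligned is still bumped to the next multiple.

assert(align<4>(10) == 12); // (10 + 4) & ~3
assert(align<4>(12) == 16); // already a multiple of 4, still moves up to the next one
assert(align<4>(13) == 16); // (13 + 4) & ~3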
template <typename SpeciesT, typename DataT, typename IndexT = int>
__global__ void pairwiseDistance(
    torch::PackedTensorAccessor32<SpeciesT, 2, torch::RestrictPtrTraits> species_t,
    torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> pos_t,
    PairDist<DataT>* d_Rij,
    IndexT max_natoms_per_mol) {
  extern __shared__ DataT spos[];
  DataT* sx = &spos[0];
  DataT* sy = &spos[max_natoms_per_mol];
  DataT* sz = &spos[2 * max_natoms_per_mol];

  int mol_idx = blockIdx.x;
  int tidx = threadIdx.y * blockDim.x + threadIdx.x;

  for (int i = tidx; i < max_natoms_per_mol; i += blockDim.x * blockDim.y) {
    sx[i] = pos_t[mol_idx][i][0];
    sy[i] = pos_t[mol_idx][i][1];
    sz[i] = pos_t[mol_idx][i][2];
  }

  __syncthreads();

  int natom_pairs = max_natoms_per_mol * max_natoms_per_mol;

  for (int i = threadIdx.y; i < max_natoms_per_mol; i += blockDim.y) {
    SpeciesT type_i = species_t[mol_idx][i];

    DataT xi = sx[i];
    DataT yi = sy[i];
    DataT zi = sz[i];

    for (int j = threadIdx.x; j < max_natoms_per_mol; j += blockDim.x) {
      SpeciesT type_j = species_t[mol_idx][j];

      const DataT xj = sx[j];
      const DataT yj = sy[j];
      const DataT zj = sz[j];
      const DataT delx = xj - xi;
      const DataT dely = yj - yi;
      const DataT delz = zj - zi;

      const DataT Rsq = delx * delx + dely * dely + delz * delz;
      if (type_i != -1 && type_j != -1 && i != j) {
        DataT Rij = sqrt(Rsq);

        PairDist<DataT> d;
        d.Rij = Rij;
        d.midx = mol_idx;
        d.i = i;
        d.j = j;

        d_Rij[mol_idx * natom_pairs + i * max_natoms_per_mol + j] = d;
      }
    }
  }
}
template <typename SpeciesT, typename DataT, typename IndexT = int, int TILEX = 8, int TILEY = 4>
__global__ void cuAngularAEVs(
    torch::PackedTensorAccessor32<SpeciesT, 2, torch::RestrictPtrTraits> species_t,
    torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> pos_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> ShfA_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> ShfZ_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> EtaA_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> Zeta_t,
    torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> aev_t,
    PairDist<DataT>* d_Rij,
    PairDist<DataT>* d_centralAtom,
    int* d_nPairsPerCenterAtom,
    int* d_centerAtomStartIdx,
    AEVScalarParams<DataT, IndexT> aev_params,
    int maxnbrs_per_atom_aligned,
    int angular_length_aligned,
    int ncentral_atoms) {
  extern __shared__ DataT smem[];

  int threads_per_catom = TILEX * TILEY;
  int gIdx = blockIdx.x * blockDim.x + threadIdx.x;
  int cIdx = gIdx / threads_per_catom; // central atom id

  if (cIdx >= ncentral_atoms)
    return;

  int groupIdx = threadIdx.x / threads_per_catom;
  int laneIdx = threadIdx.x % threads_per_catom;
  int ncatom_per_tpb = blockDim.x / threads_per_catom;

  DataT* saev = &smem[groupIdx * angular_length_aligned];

  int offset = ncatom_per_tpb * angular_length_aligned;
  DataT* sdx = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];

  offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;
  DataT* sdy = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];

  offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;
  DataT* sdz = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];

  offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;
  DataT* sdist = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];

  offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;
  DataT* sfc = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];

  offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;
  int* stype = (int*)&smem[offset + groupIdx * maxnbrs_per_atom_aligned];

  DataT EtaA = EtaA_t[0];
  DataT Zeta = Zeta_t[0];

  IndexT nShfA = ShfA_t.size(0);
  IndexT nShfZ = ShfZ_t.size(0);
  DataT Rca = aev_params.Rca;
  IndexT num_species = aev_params.num_species;

  PairDist<DataT> d = d_centralAtom[cIdx];
  int start_idx = d_centerAtomStartIdx[cIdx];
  int jnum = d_nPairsPerCenterAtom[cIdx];

  // center atom
  int i = d.i;
  int mol_idx = d.midx;

  for (int iaev = laneIdx; iaev < aev_params.angular_length; iaev += threads_per_catom) {
    saev[iaev] = 0;
  }

  DataT xi = pos_t[mol_idx][i][0];
  DataT yi = pos_t[mol_idx][i][1];
  DataT zi = pos_t[mol_idx][i][2];

  for (int jj = laneIdx; jj < jnum; jj += threads_per_catom) {
    PairDist<DataT> dij = d_Rij[start_idx + jj];
    int j = dij.j;
    DataT Rij = dij.Rij;
    SpeciesT type_j = species_t[mol_idx][j];
    sdx[jj] = pos_t[mol_idx][j][0] - xi;
    sdy[jj] = pos_t[mol_idx][j][1] - yi;
    sdz[jj] = pos_t[mol_idx][j][2] - zi;
    stype[jj] = type_j;
    sdist[jj] = Rij;
    DataT fc_ij = 0.5 * cos(PI * Rij / Rca) + 0.5;
    sfc[jj] = fc_ij;
  }

  short2 tile = make_short2(laneIdx % TILEX, laneIdx / TILEX);

  for (int jj = 0; jj < jnum; jj++) {
    const DataT Rij = sdist[jj];
    SpeciesT type_j = stype[jj];
    DataT fc_ij = sfc[jj];

    for (int kk_start = jj + 1; kk_start < jnum; kk_start += threads_per_catom) {
      int kk = kk_start + laneIdx;
      DataT theta = 0;
      if (kk < jnum) {
        const DataT Rik = sdist[kk];
        theta = acos(0.95 * (sdx[jj] * sdx[kk] + sdy[jj] * sdy[kk] + sdz[jj] * sdz[kk]) / (Rij * Rik));
      }

      for (int srcLane = 0; kk_start + srcLane < min(32, jnum); ++srcLane) {
        int kk = kk_start + srcLane;
        DataT theta_ijk = __shfl_sync(0xFFFFFFFF, theta, srcLane);

        const DataT Rik = sdist[kk];
        SpeciesT type_k = stype[kk];

        DataT fc_ik = sfc[kk];

        DataT Rijk = (Rij + Rik) / 2;
        DataT fc_ijk = fc_ij * fc_ik;

        IndexT subaev_offset = csubaev_offsets[type_j * num_species + type_k];
        IndexT aev_offset = aev_params.radial_length + subaev_offset;

        for (int itheta = tile.x; itheta < nShfZ; itheta += TILEX) {
          DataT ShfZ = ShfZ_t[itheta];

          DataT factor1 = pow((1 + cos(theta_ijk - ShfZ)) / 2, Zeta);

          for (int ishfr = tile.y; ishfr < nShfA; ishfr += TILEY) {
            DataT ShfA = ShfA_t[ishfr];
            DataT factor2 = exp(-EtaA * (Rijk - ShfA) * (Rijk - ShfA));

            DataT res = 2 * factor1 * factor2 * fc_ijk;

            saev[subaev_offset + ishfr * nShfZ + itheta] += res;
          }
        }
      }
    }
  }

  for (int iaev = laneIdx; iaev < aev_params.angular_length; iaev += threads_per_catom) {
    aev_t[mol_idx][i][aev_params.radial_length + iaev] = saev[iaev];
  }
}
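For reference, the quantity the innermost loops accumulate into saev for every neighbor pair (j, k) of center atom i is, transcribing the code (theta_s = ShfZ, R_s = ShfA, zeta = Zeta, eta_A = EtaA):

\[
\Delta G^{A} = 2 \left( \frac{1 + \cos(\theta_{ijk} - \theta_s)}{2} \right)^{\zeta} \exp\!\left( -\eta_A \left( \frac{R_{ij} + R_{ik}}{2} - R_s \right)^{2} \right) f_C(R_{ij})\, f_C(R_{ik}),
\qquad \theta_{ijk} = \arccos\!\left( 0.95\, \frac{\vec r_{ij} \cdot \vec r_{ik}}{R_{ij} R_{ik}} \right),
\]

with the cutoff f_C(R) = 0.5 cos(pi R / Rca) + 0.5 already staged into sfc when the neighbor list was loaded into shared memory; the result lands in the angular block selected by csubaev_offsets[type_j * num_species + type_k].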
template <typename SpeciesT, typename DataT, int THREADS_PER_RIJ>
__global__ void cuRadialAEVs(
    torch::PackedTensorAccessor32<SpeciesT, 2, torch::RestrictPtrTraits> species_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> ShfR_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> EtaR_t,
    torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> aev_t,
    PairDist<DataT>* d_Rij,
    AEVScalarParams<DataT, int> aev_params,
    int nRadialRij) {
  int gidx = blockIdx.x * blockDim.x + threadIdx.x;
  int idx = gidx / THREADS_PER_RIJ;

  int nShfR = ShfR_t.size(0);
  DataT EtaR = EtaR_t[0];

  if (idx >= nRadialRij)
    return;

  int laneIdx = threadIdx.x % THREADS_PER_RIJ;

  PairDist<DataT> d = d_Rij[idx];
  DataT Rij = d.Rij;
  int mol_idx = d.midx;
  int i = d.i;
  int j = d.j;

  SpeciesT type_i = species_t[mol_idx][i];
  SpeciesT type_j = species_t[mol_idx][j];

  DataT fc = 0.5 * cos(PI * Rij / aev_params.Rcr) + 0.5;

  for (int ishfr = laneIdx; ishfr < nShfR; ishfr += THREADS_PER_RIJ) {
    DataT ShfR = ShfR_t[ishfr];

    DataT GmR = 0.25 * exp(-EtaR * (Rij - ShfR) * (Rij - ShfR)) * fc;

    atomicAdd(&aev_t[mol_idx][i][type_j * aev_params.radial_sublength + ishfr], GmR);
  }
}
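Written out, each pair (i, j) handled above adds, per radial shift R_s = ShfR[ishfr] and eta_R = EtaR (again a direct transcription of the code):

\[
\Delta G^{R} = 0.25\, \exp\!\left( -\eta_R \left( R_{ij} - R_s \right)^{2} \right) f_C(R_{ij}),
\qquad f_C(R) = \tfrac{1}{2} \cos\!\left( \pi R / R_{cr} \right) + \tfrac{1}{2},
\]

accumulated with atomicAdd into slot type_j * radial_sublength + ishfr of atom i's radial AEV.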
template <typename DataT>
void cubScan(const DataT* d_in, DataT* d_out, int num_items, cudaStream_t stream) {
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();

  // Determine temporary device storage requirements
  void* d_temp_storage = NULL;
  size_t temp_storage_bytes = 0;
  cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);

  // Allocate temporary storage
  auto buffer_tmp = allocator.allocate(temp_storage_bytes);
  d_temp_storage = buffer_tmp.get();

  // Run exclusive prefix sum
  cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
}
template <typename DataT, typename IndexT>
int cubEncode(
    const DataT* d_in,
    DataT* d_unique_out,
    IndexT* d_counts_out,
    int num_items,
    int* d_num_runs_out,
    cudaStream_t stream) {
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();

  // Determine temporary device storage requirements
  void* d_temp_storage = NULL;
  size_t temp_storage_bytes = 0;
  cub::DeviceRunLengthEncode::Encode(
      d_temp_storage, temp_storage_bytes, d_in, d_unique_out, d_counts_out, d_num_runs_out, num_items, stream);

  // Allocate temporary storage
  auto buffer_tmp = allocator.allocate(temp_storage_bytes);
  d_temp_storage = buffer_tmp.get();

  // Run encoding
  cub::DeviceRunLengthEncode::Encode(
      d_temp_storage, temp_storage_bytes, d_in, d_unique_out, d_counts_out, d_num_runs_out, num_items, stream);

  int num_selected = 0;
  cudaMemcpyAsync(&num_selected, d_num_runs_out, sizeof(int), cudaMemcpyDefault, stream);
  cudaStreamSynchronize(stream);
  return num_selected;
}

template <typename DataT, typename LambdaOpT>
int cubDeviceSelect(
    const DataT* d_in,
    DataT* d_out,
    int num_items,
    int* d_num_selected_out,
    LambdaOpT select_op,
    cudaStream_t stream) {
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();

  // Determine temporary device storage requirements
  void* d_temp_storage = NULL;
  size_t temp_storage_bytes = 0;
  cub::DeviceSelect::If(d_temp_storage, temp_storage_bytes, d_in, d_out, d_num_selected_out, num_items, select_op);

  // Allocate temporary storage
  auto buffer_tmp = allocator.allocate(temp_storage_bytes);
  d_temp_storage = buffer_tmp.get();

  // Run selection
  cub::DeviceSelect::If(
      d_temp_storage, temp_storage_bytes, d_in, d_out, d_num_selected_out, num_items, select_op, stream);

  int num_selected = 0;
  cudaMemcpyAsync(&num_selected, d_num_selected_out, sizeof(int), cudaMemcpyDefault, stream);
  cudaStreamSynchronize(stream);

  return num_selected;
}

template <typename DataT>
DataT cubMax(const DataT* d_in, int num_items, DataT* d_out, cudaStream_t stream) {
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();
  // Determine temporary device storage requirements
  void* d_temp_storage = NULL;
  size_t temp_storage_bytes = 0;
  cub::DeviceReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);

  // Allocate temporary storage
  auto buffer_tmp = allocator.allocate(temp_storage_bytes);
  d_temp_storage = buffer_tmp.get();

  // Run max-reduction
  cub::DeviceReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);

  int maxVal = 0;
  cudaMemcpyAsync(&maxVal, d_out, sizeof(DataT), cudaMemcpyDefault, stream);
  cudaStreamSynchronize(stream);

  return maxVal;
}
void initConsts(AEVScalarParams<float>& aev_params, cudaStream_t stream) {
  int num_species = aev_params.num_species;
  assert(num_species <= MAX_NSPECIES);
  // precompute the aev offsets and load to constant memory
  int* subaev_offsets = new int[num_species * num_species];
  for (int t = 0; t < num_species; ++t) {
    int offset = 0;
    for (int s = 0; s < num_species; s++) {
      if (t < num_species - s) {
        subaev_offsets[s * num_species + s + t] = aev_params.angular_sublength * (offset + t);
        subaev_offsets[(s + t) * num_species + s] = aev_params.angular_sublength * (offset + t);
      }
      offset += num_species - s;
    }
  }
  cudaMemcpyToSymbolAsync(
      csubaev_offsets, subaev_offsets, sizeof(int) * num_species * num_species, 0, cudaMemcpyDefault, stream);
  delete[] subaev_offsets;
}

// NOTE: assumes size of EtaA_t = Zeta_t = EtaR_t = 1
template <typename ScalarRealT = float>
torch::Tensor cuComputeAEV(
    torch::Tensor coordinates_t,
    torch::Tensor species_t,
    double Rcr_,
    double Rca_,
    torch::Tensor EtaR_t,
    torch::Tensor ShfR_t,
    torch::Tensor EtaA_t,
    torch::Tensor Zeta_t,
    torch::Tensor ShfA_t,
    torch::Tensor ShfZ_t,
    int64_t num_species_) {
  TORCH_CHECK(
      (species_t.dtype() == torch::kInt32) && (coordinates_t.dtype() == torch::kFloat32), "Unsupported input type");
  TORCH_CHECK(
      EtaR_t.size(0) == 1 || EtaA_t.size(0) == 1 || Zeta_t.size(0) == 1,
      "cuda extension is currently not supported for the specified "
      "configuration");

  ScalarRealT Rcr = Rcr_;
  ScalarRealT Rca = Rca_;
  int num_species = num_species_;

  const int n_molecules = species_t.size(0);
  const int max_natoms_per_mol = species_t.size(1);

  AEVScalarParams<float> aev_params;
  aev_params.Rca = Rca;
  aev_params.Rcr = Rcr;
  aev_params.num_species = num_species;

  aev_params.radial_sublength = EtaR_t.size(0) * ShfR_t.size(0);
  aev_params.radial_length = aev_params.radial_sublength * num_species;

  aev_params.angular_sublength = EtaA_t.size(0) * Zeta_t.size(0) * ShfA_t.size(0) * ShfZ_t.size(0);
  aev_params.angular_length = aev_params.angular_sublength * (num_species * (num_species + 1) / 2);

  int aev_length = aev_params.radial_length + aev_params.angular_length;

  auto aev_t = torch::zeros({n_molecules, max_natoms_per_mol, aev_length}, coordinates_t.options());

  if (species_t.numel() == 0) {
    return aev_t;
  }

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  auto thrust_allocator = THCThrustAllocator(at::globalContext().lazyInitCUDA());
  auto policy = thrust::cuda::par(thrust_allocator).on(stream);
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();

  // precompute the aev offsets and load to constant memory
  initConsts(aev_params, stream);

  // buffer to store all the pairwise distances (Rij)
  auto total_natom_pairs = n_molecules * max_natoms_per_mol * max_natoms_per_mol;
  auto buffer_Rij = allocator.allocate(sizeof(PairDist<float>) * total_natom_pairs);
  PairDist<float>* d_Rij = (PairDist<float>*)buffer_Rij.get();

  // init all Rij to inf
  PairDist<float> init;
  init.Rij = std::numeric_limits<float>::infinity();
  thrust::fill(policy, d_Rij, d_Rij + total_natom_pairs, init);

  // buffer to store all the pairwise distances needed for Radial AEV computation
  auto buffer_radialRij = allocator.allocate(sizeof(PairDist<float>) * total_natom_pairs);
  PairDist<float>* d_radialRij = (PairDist<float>*)buffer_radialRij.get();

  auto buffer_count = allocator.allocate(sizeof(int));
  int* d_count_out = (int*)buffer_count.get();

  const int block_size = 64;

  dim3 block(8, 8, 1);
  // Compute pairwise distance (Rij) for all atom pairs in a molecule
  pairwiseDistance<<<n_molecules, block, sizeof(float) * max_natoms_per_mol * 3, stream>>>(
      species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
      coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
      d_Rij,
      max_natoms_per_mol);

  // Extract the Rijs that are needed for RadialAEV computation, i.e. all the Rij <= Rcr
  int nRadialRij = cubDeviceSelect(
      d_Rij,
      d_radialRij,
      total_natom_pairs,
      d_count_out,
      [=] __device__(const PairDist<float> d) { return d.Rij <= Rcr; },
      stream);

  int nblocks = (nRadialRij * 8 + block_size - 1) / block_size;
  cuRadialAEVs<int, float, 8><<<nblocks, block_size, 0, stream>>>(
      species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
      ShfR_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
      EtaR_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
      aev_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
      d_radialRij,
      aev_params,
      nRadialRij);

  // reuse buffer allocated for all Rij
  // d_angularRij will store all the Rij required in Angular AEV computation
  PairDist<float>* d_angularRij = d_Rij;

  // Extract the Rijs that are needed for AngularAEV computation, i.e. all the Rij <= Rca
  int nAngularRij = cubDeviceSelect(
      d_radialRij,
      d_angularRij,
      nRadialRij,
      d_count_out,
      [=] __device__(const PairDist<float> d) { return d.Rij <= Rca; },
      stream);

  auto buffer_centralAtom = allocator.allocate(sizeof(PairDist<float>) * nAngularRij);
  PairDist<float>* d_centralAtom = (PairDist<float>*)buffer_centralAtom.get();

  auto buffer_numPairsPerCenterAtom = allocator.allocate(sizeof(int) * nAngularRij);
  int* d_numPairsPerCenterAtom = (int*)buffer_numPairsPerCenterAtom.get();

  // group by center atom
  int ncenter_atoms = cubEncode(d_angularRij, d_centralAtom, d_numPairsPerCenterAtom, nAngularRij, d_count_out, stream);

  auto buffer_centerAtomStartIdx = allocator.allocate(sizeof(int) * ncenter_atoms);
  int* d_centerAtomStartIdx = (int*)buffer_centerAtomStartIdx.get();

  cubScan(d_numPairsPerCenterAtom, d_centerAtomStartIdx, ncenter_atoms, stream);

  {
    const int nthreads_per_catom = 32;
    const int nblocks_angAEV = (ncenter_atoms * nthreads_per_catom + block_size - 1) / block_size;
    auto smem_size = [&aev_params](int max_nbrs, int ncatom_per_tpb) {
      int sm_aev = sizeof(float) * align<4>(aev_params.angular_length);
      int sxyz = sizeof(float) * max_nbrs * 3;
      int sRij = sizeof(float) * max_nbrs;
      int sfc = sizeof(float) * max_nbrs;
      int sj = sizeof(int) * max_nbrs;

      return (sm_aev + sxyz + sRij + sfc + sj) * ncatom_per_tpb;
    };

    int maxNbrsPerCenterAtom = cubMax(d_numPairsPerCenterAtom, ncenter_atoms, d_count_out, stream);

    int maxnbrs_per_atom_aligned = align<4>(maxNbrsPerCenterAtom);

    cuAngularAEVs<<<
        nblocks_angAEV,
        block_size,
        smem_size(maxnbrs_per_atom_aligned, block_size / nthreads_per_catom),
        stream>>>(
        species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
        coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
        ShfA_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
        ShfZ_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
        EtaA_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
        Zeta_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
        aev_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
        d_angularRij,
        d_centralAtom,
        d_numPairsPerCenterAtom,
        d_centerAtomStartIdx,
        aev_params,
        maxnbrs_per_atom_aligned,
        align<4>(aev_params.angular_length),
        ncenter_atoms);
  }
  return aev_t;
}

TORCH_LIBRARY(cuaev, m) {
  m.def("cuComputeAEV", &cuComputeAEV<float>);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
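For orientation, a minimal host-side sketch of how cuComputeAEV might be exercised once the extension is built (an illustration only: it assumes the code above is visible from the same translation unit, and every tensor shape and parameter value below is invented rather than taken from the commit):

// hypothetical driver, not part of the commit
#include <torch/torch.h>

int main() {
  auto opt = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
  auto coordinates = torch::rand({2, 8, 3}, opt);                 // 2 molecules, 8 atoms each
  auto species = torch::zeros({2, 8}, opt.dtype(torch::kInt32));  // species indices; -1 would mark padding
  auto EtaR = torch::full({1}, 16.0f, opt);                       // single-element eta/zeta tensors, as the NOTE above assumes
  auto ShfR = torch::linspace(0.9, 5.2, 16, opt);
  auto EtaA = torch::full({1}, 8.0f, opt);
  auto Zeta = torch::full({1}, 32.0f, opt);
  auto ShfA = torch::linspace(0.9, 3.5, 4, opt);
  auto ShfZ = torch::linspace(0.2, 3.0, 8, opt);
  auto aev = cuComputeAEV<float>(coordinates, species, 5.2, 3.5, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, 4);
  // aev has shape {2, 8, radial_length + angular_length}
  return 0;
}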