Unverified commit a6d819ed, authored by Richard Xue and committed by GitHub

clang-format line-limit to 120 (#552)

* clang-format

* line-limit 120
parent 5ff2f8fc
---
AccessModifierOffset: -1
AlignAfterOpenBracket: AlwaysBreak
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
AlignEscapedNewlinesLeft: true
AlignOperands: false
AlignTrailingComments: false
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortBlocksOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: Empty
AllowShortIfStatementsOnASingleLine: false
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: true
BinPackArguments: false
BinPackParameters: false
BraceWrapping:
  AfterClass: false
  AfterControlStatement: false
  AfterEnum: false
  AfterFunction: false
  AfterNamespace: false
  AfterObjCDeclaration: false
  AfterStruct: false
  AfterUnion: false
  BeforeCatch: false
  BeforeElse: false
  IndentBraces: false
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Attach
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: false
ColumnLimit: 100
CommentPragmas: '^ IWYU pragma:'
CompactNamespaces: false
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DerivePointerAlignment: false
DisableFormat: false
ForEachMacros: [ FOR_EACH_RANGE, FOR_EACH, ]
IncludeCategories:
  - Regex: '^<.*\.h(pp)?>'
    Priority: 1
  - Regex: '^<.*'
    Priority: 2
  - Regex: '.*'
    Priority: 3
IndentCaseLabels: true
IndentWidth: 2
IndentWrappedFunctionNames: false
KeepEmptyLinesAtTheStartOfBlocks: false
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: false
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 2000000
PointerAlignment: Left
ReflowComments: true
SortIncludes: true
SpaceAfterCStyleCast: false
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: ControlStatements
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Cpp11
TabWidth: 8
UseTab: Never
...
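As a quick reference for what these options produce: a call or declaration that no longer fits within the column limit is broken with every parameter on its own line (BinPackParameters/BinPackArguments: false), a plain 4-column continuation indent (AlignAfterOpenBracket: AlwaysBreak), and the '*' bound to the type (PointerAlignment: Left). The signature below is a made-up illustration, not code from this repository.

// Hypothetical declaration, shown only to illustrate the wrapping style:
void launchExampleKernel(
    const float* positions,
    const int* species_types,
    float cutoff_radius,
    int num_atoms);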
#include <thrust/equal.h>
#include <torch/extension.h>
#include <cub/cub.cuh>

#include <ATen/Context.h>
#include <THC/THC.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <THC/THCThrustAllocator.cuh>

#define PI 3.141592653589793

template <typename DataT, typename IndexT = int>
struct AEVScalarParams {
  DataT Rcr;
  DataT Rca;
@@ -23,7 +24,8 @@ template <typename DataT, typename IndexT = int> struct AEVScalarParams {
#define MAX_NSPECIES 10
__constant__ int csubaev_offsets[MAX_NSPECIES * MAX_NSPECIES];

template <typename DataT>
struct PairDist {
  DataT Rij;
  int midx;
  short i;
@@ -32,8 +34,7 @@ template <typename DataT> struct PairDist {
// used to group Rijs by atom id
template <typename DataT>
__host__ __device__ bool operator==(const PairDist<DataT>& lhs, const PairDist<DataT>& rhs) {
  return lhs.midx == rhs.midx && lhs.i == rhs.i;
}
@@ -42,23 +43,21 @@ __host__ __device__ bool operator==(const PairDist<DataT> &lhs,
/// \param value Input value that is to be aligned
/// \return Value aligned to boundary
template <int32_t boundary>
__host__ __device__ __forceinline__ int align(const int& value) {
  static_assert((boundary & (boundary - 1)) == 0, "Boundary for align must be power of 2");
  return (value + boundary) & ~(boundary - 1);
}
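One detail worth noting about align: because it adds the full boundary before masking, the result is always strictly larger than the input, even when the input is already a multiple of boundary; the shared-memory slices sized with it below therefore always carry some padding. The static_asserts here just replay the same expression with concrete numbers chosen for illustration:

// (value + boundary) & ~(boundary - 1), written out for boundary = 4:
static_assert(((5 + 4) & ~(4 - 1)) == 8, "align<4>(5) rounds 5 up to 8");
static_assert(((8 + 4) & ~(4 - 1)) == 12, "align<4>(8) pads an already-aligned value up to 12");
static_assert(((10 + 4) & ~(4 - 1)) == 12, "align<4>(10) rounds 10 up to 12");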
template <typename SpeciesT, typename DataT, typename IndexT = int>
__global__ void pairwiseDistance(
    torch::PackedTensorAccessor32<SpeciesT, 2, torch::RestrictPtrTraits> species_t,
    torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> pos_t,
    PairDist<DataT>* d_Rij,
    IndexT max_natoms_per_mol) {
  extern __shared__ DataT spos[];
  DataT* sx = &spos[0];
  DataT* sy = &spos[max_natoms_per_mol];
  DataT* sz = &spos[2 * max_natoms_per_mol];

  int mol_idx = blockIdx.x;
  int tidx = threadIdx.y * blockDim.x + threadIdx.x;
@@ -74,7 +73,6 @@ __global__ void pairwiseDistance(
  int natom_pairs = max_natoms_per_mol * max_natoms_per_mol;
  for (int i = threadIdx.y; i < max_natoms_per_mol; i += blockDim.y) {
    SpeciesT type_i = species_t[mol_idx][i];
    DataT xi = sx[i];
@@ -107,21 +105,23 @@ __global__ void pairwiseDistance(
  }
}

template <typename SpeciesT, typename DataT, typename IndexT = int, int TILEX = 8, int TILEY = 4>
__global__ void cuAngularAEVs(
    torch::PackedTensorAccessor32<SpeciesT, 2, torch::RestrictPtrTraits> species_t,
    torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> pos_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> ShfA_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> ShfZ_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> EtaA_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> Zeta_t,
    torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> aev_t,
    PairDist<DataT>* d_Rij,
    PairDist<DataT>* d_centralAtom,
    int* d_nPairsPerCenterAtom,
    int* d_centerAtomStartIdx,
    AEVScalarParams<DataT, IndexT> aev_params,
    int maxnbrs_per_atom_aligned,
    int angular_length_aligned,
    int ncentral_atoms) {
  extern __shared__ DataT smem[];
  int threads_per_catom = TILEX * TILEY;
@@ -135,25 +135,25 @@ __global__ void cuAngularAEVs(
  int laneIdx = threadIdx.x % threads_per_catom;
  int ncatom_per_tpb = blockDim.x / threads_per_catom;

  DataT* saev = &smem[groupIdx * angular_length_aligned];
  int offset = ncatom_per_tpb * angular_length_aligned;
  DataT* sdx = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];
  offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;
  DataT* sdy = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];
  offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;
  DataT* sdz = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];
  offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;
  DataT* sdist = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];
  offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;
  DataT* sfc = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];
  offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;
  int* stype = (int*)&smem[offset + groupIdx * maxnbrs_per_atom_aligned];

  DataT EtaA = EtaA_t[0];
  DataT Zeta = Zeta_t[0];
@@ -171,8 +171,7 @@ __global__ void cuAngularAEVs(
  int i = d.i;
  int mol_idx = d.midx;

  for (int iaev = laneIdx; iaev < aev_params.angular_length; iaev += threads_per_catom) {
    saev[iaev] = 0;
  }
@@ -202,16 +201,13 @@ __global__ void cuAngularAEVs(
    DataT fc_ij = sfc[jj];

    for (int kk_start = jj + 1; kk_start < jnum; kk_start += threads_per_catom) {
      int kk = kk_start + laneIdx;
      DataT theta = 0;
      if (kk < jnum) {
        const DataT Rik = sdist[kk];
        theta =
            acos(0.95 * (sdx[jj] * sdx[kk] + sdy[jj] * sdy[kk] + sdz[jj] * sdz[kk]) / (Rij * Rik));
      }

      for (int srcLane = 0; kk_start + srcLane < min(32, jnum); ++srcLane) {
@@ -247,22 +243,20 @@ __global__ void cuAngularAEVs(
    }
  }

  for (int iaev = laneIdx; iaev < aev_params.angular_length; iaev += threads_per_catom) {
    aev_t[mol_idx][i][aev_params.radial_length + iaev] = saev[iaev];
  }
}

template <typename SpeciesT, typename DataT, int THREADS_PER_RIJ>
__global__ void cuRadialAEVs(
    torch::PackedTensorAccessor32<SpeciesT, 2, torch::RestrictPtrTraits> species_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> ShfR_t,
    torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> EtaR_t,
    torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> aev_t,
    PairDist<DataT>* d_Rij,
    AEVScalarParams<DataT, int> aev_params,
    int nRadialRij) {
  int gidx = blockIdx.x * blockDim.x + threadIdx.x;
  int idx = gidx / THREADS_PER_RIJ;
@@ -290,104 +284,123 @@ __global__ void cuRadialAEVs(
    DataT GmR = 0.25 * exp(-EtaR * (Rij - ShfR) * (Rij - ShfR)) * fc;
    atomicAdd(&aev_t[mol_idx][i][type_j * aev_params.radial_sublength + ishfr], GmR);
  }
}
template <typename DataT>
void cubScan(const DataT* d_in, DataT* d_out, int num_items, cudaStream_t stream) {
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();

  // Determine temporary device storage requirements
  void* d_temp_storage = NULL;
  size_t temp_storage_bytes = 0;
  cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);

  // Allocate temporary storage
  auto buffer_tmp = allocator.allocate(temp_storage_bytes);
  d_temp_storage = buffer_tmp.get();

  // Run exclusive prefix sum
  cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
}
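cubScan wraps cub::DeviceScan::ExclusiveSum with CUB's usual two-pass pattern: query the temporary-storage size, allocate it from the caching allocator, then run. Below is a small host-side reference of the exclusive scan itself, with made-up counts, showing the shape of the output that cuComputeAEV later uses as per-center-atom start offsets:

#include <vector>

// Host-side reference of what cub::DeviceScan::ExclusiveSum computes (illustration only):
// each output element is the sum of all inputs strictly before it.
std::vector<int> exclusiveSumReference(const std::vector<int>& counts) {
  std::vector<int> starts(counts.size());
  int running = 0;
  for (size_t k = 0; k < counts.size(); ++k) {
    starts[k] = running; // e.g. counts {3, 1, 4, 2} -> starts {0, 3, 4, 8}
    running += counts[k];
  }
  return starts;
}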
template <typename DataT, typename IndexT>
int cubEncode(
    const DataT* d_in,
    DataT* d_unique_out,
    IndexT* d_counts_out,
    int num_items,
    int* d_num_runs_out,
    cudaStream_t stream) {
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();

  // Determine temporary device storage requirements
  void* d_temp_storage = NULL;
  size_t temp_storage_bytes = 0;
  cub::DeviceRunLengthEncode::Encode(
      d_temp_storage,
      temp_storage_bytes,
      d_in,
      d_unique_out,
      d_counts_out,
      d_num_runs_out,
      num_items,
      stream);

  // Allocate temporary storage
  auto buffer_tmp = allocator.allocate(temp_storage_bytes);
  d_temp_storage = buffer_tmp.get();

  // Run encoding
  cub::DeviceRunLengthEncode::Encode(
      d_temp_storage,
      temp_storage_bytes,
      d_in,
      d_unique_out,
      d_counts_out,
      d_num_runs_out,
      num_items,
      stream);

  int num_selected = 0;
  cudaMemcpyAsync(&num_selected, d_num_runs_out, sizeof(int), cudaMemcpyDefault, stream);
  cudaStreamSynchronize(stream);
  return num_selected;
}
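cubEncode wraps cub::DeviceRunLengthEncode::Encode; in cuComputeAEV below it runs over the angular pair list, where equality is the PairDist operator== defined earlier (same molecule, same center atom), so each run corresponds to one center atom. A host-side reference with made-up integer keys:

#include <vector>

// Host-side reference of run-length encoding (illustration only):
// consecutive equal keys collapse into one (unique, count) pair, e.g.
//   keys {7, 7, 7, 2, 9, 9} -> unique {7, 2, 9}, counts {3, 1, 2}, num_runs 3
int runLengthEncodeReference(
    const std::vector<int>& keys,
    std::vector<int>& unique_out,
    std::vector<int>& counts_out) {
  for (size_t k = 0; k < keys.size(); ++k) {
    if (k == 0 || keys[k] != keys[k - 1]) {
      unique_out.push_back(keys[k]); // a new run starts here
      counts_out.push_back(1);
    } else {
      ++counts_out.back(); // extend the current run
    }
  }
  return static_cast<int>(unique_out.size());
}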
template <typename DataT, typename LambdaOpT>
int cubDeviceSelect(
    const DataT* d_in,
    DataT* d_out,
    int num_items,
    int* d_num_selected_out,
    LambdaOpT select_op,
    cudaStream_t stream) {
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();

  // Determine temporary device storage requirements
  void* d_temp_storage = NULL;
  size_t temp_storage_bytes = 0;
  cub::DeviceSelect::If(
      d_temp_storage, temp_storage_bytes, d_in, d_out, d_num_selected_out, num_items, select_op);

  // Allocate temporary storage
  auto buffer_tmp = allocator.allocate(temp_storage_bytes);
  d_temp_storage = buffer_tmp.get();

  // Run selection
  cub::DeviceSelect::If(
      d_temp_storage,
      temp_storage_bytes,
      d_in,
      d_out,
      d_num_selected_out,
      num_items,
      select_op,
      stream);

  int num_selected = 0;
  cudaMemcpyAsync(&num_selected, d_num_selected_out, sizeof(int), cudaMemcpyDefault, stream);
  cudaStreamSynchronize(stream);
  return num_selected;
}
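The select_op arguments passed to cubDeviceSelect later in this file are host-defined __device__ lambdas, which rely on nvcc's extended-lambda support. An equivalent formulation, sketched here as an assumption rather than something this commit contains, is a named functor with a __device__ call operator:

// Hypothetical functor equivalent to the `d.Rij <= cutoff` lambdas used in cuComputeAEV.
template <typename DataT>
struct WithinCutoff {
  DataT cutoff;
  __device__ bool operator()(const PairDist<DataT>& d) const {
    return d.Rij <= cutoff;
  }
};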
template <typename DataT>
DataT cubMax(const DataT* d_in, int num_items, DataT* d_out, cudaStream_t stream) {
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();

  // Determine temporary device storage requirements
  void* d_temp_storage = NULL;
  size_t temp_storage_bytes = 0;
  cub::DeviceReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);

  // Allocate temporary storage
  auto buffer_tmp = allocator.allocate(temp_storage_bytes);
  d_temp_storage = buffer_tmp.get();

  // Run max-reduction
  cub::DeviceReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);

  int maxVal = 0;
  cudaMemcpyAsync(&maxVal, d_out, sizeof(DataT), cudaMemcpyDefault, stream);
@@ -396,40 +409,50 @@ DataT cubMax(const DataT *d_in, int num_items, DataT *d_out,
  return maxVal;
}

void initConsts(AEVScalarParams<float>& aev_params, cudaStream_t stream) {
  int num_species = aev_params.num_species;
  assert(num_species <= MAX_NSPECIES);
  // precompute the aev offsets and load to constant memory
  int* subaev_offsets = new int[num_species * num_species];
  for (int t = 0; t < num_species; ++t) {
    int offset = 0;
    for (int s = 0; s < num_species; s++) {
      if (t < num_species - s) {
        subaev_offsets[s * num_species + s + t] = aev_params.angular_sublength * (offset + t);
        subaev_offsets[(s + t) * num_species + s] = aev_params.angular_sublength * (offset + t);
      }
      offset += num_species - s;
    }
  }
  cudaMemcpyToSymbolAsync(
      csubaev_offsets,
      subaev_offsets,
      sizeof(int) * num_species * num_species,
      0,
      cudaMemcpyDefault,
      stream);
  delete[] subaev_offsets;
}
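The nested loop above fills csubaev_offsets symmetrically: the unordered species pair (s, s + t) is assigned the (offset + t)-th angular block, and both index orders map to the same block. For a hypothetical num_species = 3 with angular_sublength written as A, the table works out to:

// csubaev_offsets indexed as [species_j * num_species + species_k], num_species = 3:
//
//          k = 0   k = 1   k = 2
//   j = 0    0*A     1*A     2*A
//   j = 1    1*A     3*A     4*A
//   j = 2    2*A     4*A     5*A
//
// Each of the 3 * (3 + 1) / 2 = 6 unordered species pairs owns one block of size
// angular_sublength, and (j, k) and (k, j) share the same block.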
// NOTE: assumes size of EtaA_t = Zeta_t = EtaR_t = 1
template <typename ScalarRealT = float>
torch::Tensor cuComputeAEV(
    torch::Tensor coordinates_t,
    torch::Tensor species_t,
    double Rcr_,
    double Rca_,
    torch::Tensor EtaR_t,
    torch::Tensor ShfR_t,
    torch::Tensor EtaA_t,
    torch::Tensor Zeta_t,
    torch::Tensor ShfA_t,
    torch::Tensor ShfZ_t,
    int64_t num_species_) {
  TORCH_CHECK(
      (species_t.dtype() == torch::kInt32) && (coordinates_t.dtype() == torch::kFloat32),
      "Unsupported input type");
  TORCH_CHECK(
      EtaR_t.size(0) == 1 || EtaA_t.size(0) == 1 || Zeta_t.size(0) == 1,
      "cuda extension is currently not supported for the specified "
      "configuration");
@@ -448,35 +471,29 @@ torch::Tensor cuComputeAEV(torch::Tensor coordinates_t, torch::Tensor species_t,
  aev_params.radial_sublength = EtaR_t.size(0) * ShfR_t.size(0);
  aev_params.radial_length = aev_params.radial_sublength * num_species;

  aev_params.angular_sublength = EtaA_t.size(0) * Zeta_t.size(0) * ShfA_t.size(0) * ShfZ_t.size(0);
  aev_params.angular_length = aev_params.angular_sublength * (num_species * (num_species + 1) / 2);

  int aev_length = aev_params.radial_length + aev_params.angular_length;
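To make the bookkeeping concrete, here is the same arithmetic with illustrative parameter counts (these numbers are assumptions for the example, not values taken from this commit): 1 EtaR x 16 ShfR, 1 EtaA x 1 Zeta x 8 ShfA x 4 ShfZ, and 4 species.

// radial_sublength  = 1 * 16                 = 16
// radial_length     = 16 * 4                 = 64
// angular_sublength = 1 * 1 * 8 * 4          = 32
// angular_length    = 32 * (4 * (4 + 1) / 2) = 320
// aev_length        = 64 + 320               = 384   // per-atom AEV width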
  auto aev_t = torch::zeros({n_molecules, max_natoms_per_mol, aev_length}, coordinates_t.options());

  if (species_t.numel() == 0) {
    return aev_t;
  }

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  auto thrust_allocator = THCThrustAllocator(at::globalContext().lazyInitCUDA());
  auto policy = thrust::cuda::par(thrust_allocator).on(stream);
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();

  // precompute the aev offsets and load to constant memory
  initConsts(aev_params, stream);

  // buffer to store all the pairwise distances (Rij)
  auto total_natom_pairs = n_molecules * max_natoms_per_mol * max_natoms_per_mol;
  auto buffer_Rij = allocator.allocate(sizeof(PairDist<float>) * total_natom_pairs);
  PairDist<float>* d_Rij = (PairDist<float>*)buffer_Rij.get();

  // init all Rij to inf
  PairDist<float> init;
@@ -485,28 +502,31 @@ torch::Tensor cuComputeAEV(torch::Tensor coordinates_t, torch::Tensor species_t,
  // buffer to store all the pairwise distances needed for radial AEV computation
  auto buffer_radialRij = allocator.allocate(sizeof(PairDist<float>) * total_natom_pairs);
  PairDist<float>* d_radialRij = (PairDist<float>*)buffer_radialRij.get();

  auto buffer_count = allocator.allocate(sizeof(int));
  int* d_count_out = (int*)buffer_count.get();

  const int block_size = 64;
  dim3 block(8, 8, 1);
  // Compute pairwise distances (Rij) for all atom pairs in a molecule
  pairwiseDistance<<<n_molecules, block, sizeof(float) * max_natoms_per_mol * 3, stream>>>(
      species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
      coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
      d_Rij,
      max_natoms_per_mol);

  // Extract the Rijs needed for radial AEV computation, i.e. all Rij <= Rcr
  int nRadialRij = cubDeviceSelect(
      d_Rij,
      d_radialRij,
      total_natom_pairs,
      d_count_out,
      [=] __device__(const PairDist<float> d) { return d.Rij <= Rcr; },
      stream);

  int nblocks = (nRadialRij * 8 + block_size - 1) / block_size;
  cuRadialAEVs<int, float, 8><<<nblocks, block_size, 0, stream>>>(
@@ -514,40 +534,41 @@ torch::Tensor cuComputeAEV(torch::Tensor coordinates_t, torch::Tensor species_t,
      ShfR_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
      EtaR_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
      aev_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
      d_radialRij,
      aev_params,
      nRadialRij);

  // reuse buffer allocated for all Rij
  // d_angularRij will store all the Rij required in angular AEV computation
  PairDist<float>* d_angularRij = d_Rij;

  // Extract the Rijs needed for angular AEV computation, i.e. all Rij <= Rca
  int nAngularRij = cubDeviceSelect(
      d_radialRij,
      d_angularRij,
      nRadialRij,
      d_count_out,
      [=] __device__(const PairDist<float> d) { return d.Rij <= Rca; },
      stream);

  auto buffer_centralAtom = allocator.allocate(sizeof(PairDist<float>) * nAngularRij);
  PairDist<float>* d_centralAtom = (PairDist<float>*)buffer_centralAtom.get();

  auto buffer_numPairsPerCenterAtom = allocator.allocate(sizeof(int) * nAngularRij);
  int* d_numPairsPerCenterAtom = (int*)buffer_numPairsPerCenterAtom.get();

  // group by center atom
  int ncenter_atoms = cubEncode(
      d_angularRij, d_centralAtom, d_numPairsPerCenterAtom, nAngularRij, d_count_out, stream);

  auto buffer_centerAtomStartIdx = allocator.allocate(sizeof(int) * ncenter_atoms);
  int* d_centerAtomStartIdx = (int*)buffer_centerAtomStartIdx.get();

  cubScan(d_numPairsPerCenterAtom, d_centerAtomStartIdx, ncenter_atoms, stream);

  {
    const int nthreads_per_catom = 32;
    const int nblocks_angAEV = (ncenter_atoms * nthreads_per_catom + block_size - 1) / block_size;
    auto smem_size = [&aev_params](int max_nbrs, int ncatom_per_tpb) {
      int sm_aev = sizeof(float) * align<4>(aev_params.angular_length);
      int sxyz = sizeof(float) * max_nbrs * 3;
@@ -558,14 +579,14 @@ torch::Tensor cuComputeAEV(torch::Tensor coordinates_t, torch::Tensor species_t,
      return (sm_aev + sxyz + sRij + sfc + sj) * ncatom_per_tpb;
    };

    int maxNbrsPerCenterAtom = cubMax(d_numPairsPerCenterAtom, ncenter_atoms, d_count_out, stream);
    int maxnbrs_per_atom_aligned = align<4>(maxNbrsPerCenterAtom);
    cuAngularAEVs<<<
        nblocks_angAEV,
        block_size,
        smem_size(maxnbrs_per_atom_aligned, block_size / nthreads_per_catom),
        stream>>>(
        species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
        coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
@@ -574,13 +595,20 @@ torch::Tensor cuComputeAEV(torch::Tensor coordinates_t, torch::Tensor species_t,
        EtaA_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
        Zeta_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
        aev_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
        d_angularRij,
        d_centralAtom,
        d_numPairsPerCenterAtom,
        d_centerAtomStartIdx,
        aev_params,
        maxnbrs_per_atom_aligned,
        align<4>(aev_params.angular_length),
        ncenter_atoms);
  }
  return aev_t;
}

TORCH_LIBRARY(cuaev, m) {
  m.def("cuComputeAEV", &cuComputeAEV<float>);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}