Unverified Commit a6d819ed authored by Richard Xue, committed by GitHub

clang-format line-limit to 120 (#552)

* clang-format

* line-limit 120
parent 5ff2f8fc
---
AccessModifierOffset: -1
AlignAfterOpenBracket: AlwaysBreak
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
AlignEscapedNewlinesLeft: true
AlignOperands: false
AlignTrailingComments: false
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortBlocksOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: Empty
AllowShortIfStatementsOnASingleLine: false
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: true
BinPackArguments: false
BinPackParameters: false
BraceWrapping:
  AfterClass: false
  AfterControlStatement: false
  AfterEnum: false
  AfterFunction: false
  AfterNamespace: false
  AfterObjCDeclaration: false
  AfterStruct: false
  AfterUnion: false
  BeforeCatch: false
  BeforeElse: false
  IndentBraces: false
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Attach
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: false
ColumnLimit: 100
CommentPragmas: '^ IWYU pragma:'
CompactNamespaces: false
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DerivePointerAlignment: false
DisableFormat: false
ForEachMacros: [ FOR_EACH_RANGE, FOR_EACH, ]
IncludeCategories:
  - Regex: '^<.*\.h(pp)?>'
    Priority: 1
  - Regex: '^<.*'
    Priority: 2
  - Regex: '.*'
    Priority: 3
IndentCaseLabels: true
IndentWidth: 2
IndentWrappedFunctionNames: false
KeepEmptyLinesAtTheStartOfBlocks: false
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: false
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 2000000
PointerAlignment: Left
ReflowComments: true
SortIncludes: true
SpaceAfterCStyleCast: false
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: ControlStatements
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Cpp11
TabWidth: 8
UseTab: Never
...
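For context only (not part of the commit): the reformatting in the .cu diff below follows directly from a few of the options above. With AlignAfterOpenBracket: AlwaysBreak, BinPackParameters and BinPackArguments set to false, ContinuationIndentWidth: 4, and PointerAlignment: Left, any declaration or call that exceeds the configured ColumnLimit is broken after the opening parenthesis with one argument per line and the pointer attached to the type; applying the style is the usual clang-format -i with -style=file from the repository root. A minimal sketch with a hypothetical declaration (the name and parameters are illustrative, not taken from the repository):

// Hypothetical declaration, shown only to illustrate the layout the options above
// produce once the one-line form exceeds ColumnLimit: break after the open parenthesis
// (AlignAfterOpenBracket: AlwaysBreak), one parameter per line (BinPackParameters: false),
// a 4-space continuation indent, and left-aligned pointers (PointerAlignment: Left).
#include <cuda_runtime.h>

void launchPairwiseDistance(
    const float* positions,
    const int* species,
    float* pair_distances,
    int num_molecules,
    int max_atoms_per_molecule,
    cudaStream_t stream);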
#include <cub/cub.cuh>
#include <thrust/equal.h>
#include <torch/extension.h>
#include <cub/cub.cuh>
#include <ATen/Context.h>
#include <THC/THC.h>
#include <THC/THCThrustAllocator.cuh>
#include <c10/cuda/CUDACachingAllocator.h>
#include <THC/THCThrustAllocator.cuh>
#define PI 3.141592653589793
template <typename DataT, typename IndexT = int> struct AEVScalarParams {
template <typename DataT, typename IndexT = int>
struct AEVScalarParams {
DataT Rcr;
DataT Rca;
@@ -23,7 +24,8 @@ template <typename DataT, typename IndexT = int> struct AEVScalarParams {
#define MAX_NSPECIES 10
__constant__ int csubaev_offsets[MAX_NSPECIES * MAX_NSPECIES];
template <typename DataT> struct PairDist {
template <typename DataT>
struct PairDist {
DataT Rij;
int midx;
short i;
@@ -32,8 +34,7 @@ template <typename DataT> struct PairDist {
// used to group Rijs by atom id
template <typename DataT>
__host__ __device__ bool operator==(const PairDist<DataT> &lhs,
const PairDist<DataT> &rhs) {
__host__ __device__ bool operator==(const PairDist<DataT>& lhs, const PairDist<DataT>& rhs) {
return lhs.midx == rhs.midx && lhs.i == rhs.i;
}
@@ -42,23 +43,21 @@ __host__ __device__ bool operator==(const PairDist<DataT> &lhs,
/// \param value Input value that is to be aligned
/// \return Value aligned to boundary
template <int32_t boundary>
__host__ __device__ __forceinline__ int align(const int &value) {
static_assert((boundary & (boundary - 1)) == 0,
"Boundary for align must be power of 2");
__host__ __device__ __forceinline__ int align(const int& value) {
static_assert((boundary & (boundary - 1)) == 0, "Boundary for align must be power of 2");
return (value + boundary) & ~(boundary - 1);
}
template <typename SpeciesT, typename DataT, typename IndexT = int>
__global__ void pairwiseDistance(
torch::PackedTensorAccessor32<SpeciesT, 2, torch::RestrictPtrTraits>
species_t,
torch::PackedTensorAccessor32<SpeciesT, 2, torch::RestrictPtrTraits> species_t,
torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> pos_t,
PairDist<DataT> *d_Rij, IndexT max_natoms_per_mol) {
PairDist<DataT>* d_Rij,
IndexT max_natoms_per_mol) {
extern __shared__ DataT spos[];
DataT *sx = &spos[0];
DataT *sy = &spos[max_natoms_per_mol];
DataT *sz = &spos[2 * max_natoms_per_mol];
DataT* sx = &spos[0];
DataT* sy = &spos[max_natoms_per_mol];
DataT* sz = &spos[2 * max_natoms_per_mol];
int mol_idx = blockIdx.x;
int tidx = threadIdx.y * blockDim.x + threadIdx.x;
@@ -74,7 +73,6 @@ __global__ void pairwiseDistance(
int natom_pairs = max_natoms_per_mol * max_natoms_per_mol;
for (int i = threadIdx.y; i < max_natoms_per_mol; i += blockDim.y) {
SpeciesT type_i = species_t[mol_idx][i];
DataT xi = sx[i];
@@ -107,21 +105,23 @@ __global__ void pairwiseDistance(
}
}
template <typename SpeciesT, typename DataT, typename IndexT = int,
int TILEX = 8, int TILEY = 4>
template <typename SpeciesT, typename DataT, typename IndexT = int, int TILEX = 8, int TILEY = 4>
__global__ void cuAngularAEVs(
torch::PackedTensorAccessor32<SpeciesT, 2, torch::RestrictPtrTraits>
species_t,
torch::PackedTensorAccessor32<SpeciesT, 2, torch::RestrictPtrTraits> species_t,
torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> pos_t,
torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> ShfA_t,
torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> ShfZ_t,
torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> EtaA_t,
torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> Zeta_t,
torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> aev_t,
PairDist<DataT> *d_Rij, PairDist<DataT> *d_centralAtom,
int *d_nPairsPerCenterAtom, int *d_centerAtomStartIdx,
AEVScalarParams<DataT, IndexT> aev_params, int maxnbrs_per_atom_aligned,
int angular_length_aligned, int ncentral_atoms) {
PairDist<DataT>* d_Rij,
PairDist<DataT>* d_centralAtom,
int* d_nPairsPerCenterAtom,
int* d_centerAtomStartIdx,
AEVScalarParams<DataT, IndexT> aev_params,
int maxnbrs_per_atom_aligned,
int angular_length_aligned,
int ncentral_atoms) {
extern __shared__ DataT smem[];
int threads_per_catom = TILEX * TILEY;
@@ -135,25 +135,25 @@ __global__ void cuAngularAEVs(
int laneIdx = threadIdx.x % threads_per_catom;
int ncatom_per_tpb = blockDim.x / threads_per_catom;
DataT *saev = &smem[groupIdx * angular_length_aligned];
DataT* saev = &smem[groupIdx * angular_length_aligned];
int offset = ncatom_per_tpb * angular_length_aligned;
DataT *sdx = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];
DataT* sdx = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];
offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;
DataT *sdy = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];
DataT* sdy = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];
offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;
DataT *sdz = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];
DataT* sdz = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];
offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;
DataT *sdist = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];
DataT* sdist = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];
offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;
DataT *sfc = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];
DataT* sfc = &smem[offset + groupIdx * maxnbrs_per_atom_aligned];
offset += ncatom_per_tpb * maxnbrs_per_atom_aligned;
int *stype = (int *)&smem[offset + groupIdx * maxnbrs_per_atom_aligned];
int* stype = (int*)&smem[offset + groupIdx * maxnbrs_per_atom_aligned];
DataT EtaA = EtaA_t[0];
DataT Zeta = Zeta_t[0];
@@ -171,8 +171,7 @@ __global__ void cuAngularAEVs(
int i = d.i;
int mol_idx = d.midx;
for (int iaev = laneIdx; iaev < aev_params.angular_length;
iaev += threads_per_catom) {
for (int iaev = laneIdx; iaev < aev_params.angular_length; iaev += threads_per_catom) {
saev[iaev] = 0;
}
@@ -202,16 +201,13 @@ __global__ void cuAngularAEVs(
DataT fc_ij = sfc[jj];
for (int kk_start = jj + 1; kk_start < jnum;
kk_start += threads_per_catom) {
for (int kk_start = jj + 1; kk_start < jnum; kk_start += threads_per_catom) {
int kk = kk_start + laneIdx;
DataT theta = 0;
if (kk < jnum) {
const DataT Rik = sdist[kk];
theta = acos(
0.95 * (sdx[jj] * sdx[kk] + sdy[jj] * sdy[kk] + sdz[jj] * sdz[kk]) /
(Rij * Rik));
theta =
acos(0.95 * (sdx[jj] * sdx[kk] + sdy[jj] * sdy[kk] + sdz[jj] * sdz[kk]) / (Rij * Rik));
}
for (int srcLane = 0; kk_start + srcLane < min(32, jnum); ++srcLane) {
@@ -247,22 +243,20 @@ __global__ void cuAngularAEVs(
}
}
for (int iaev = laneIdx; iaev < aev_params.angular_length;
iaev += threads_per_catom) {
for (int iaev = laneIdx; iaev < aev_params.angular_length; iaev += threads_per_catom) {
aev_t[mol_idx][i][aev_params.radial_length + iaev] = saev[iaev];
}
}
template <typename SpeciesT, typename DataT, int THREADS_PER_RIJ>
__global__ void cuRadialAEVs(
torch::PackedTensorAccessor32<SpeciesT, 2, torch::RestrictPtrTraits>
species_t,
torch::PackedTensorAccessor32<SpeciesT, 2, torch::RestrictPtrTraits> species_t,
torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> ShfR_t,
torch::PackedTensorAccessor32<DataT, 1, torch::RestrictPtrTraits> EtaR_t,
torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> aev_t,
PairDist<DataT> *d_Rij, AEVScalarParams<DataT, int> aev_params,
PairDist<DataT>* d_Rij,
AEVScalarParams<DataT, int> aev_params,
int nRadialRij) {
int gidx = blockIdx.x * blockDim.x + threadIdx.x;
int idx = gidx / THREADS_PER_RIJ;
@@ -290,104 +284,123 @@ __global__ void cuRadialAEVs(
DataT GmR = 0.25 * exp(-EtaR * (Rij - ShfR) * (Rij - ShfR)) * fc;
atomicAdd(&aev_t[mol_idx][i][type_j * aev_params.radial_sublength + ishfr],
GmR);
atomicAdd(&aev_t[mol_idx][i][type_j * aev_params.radial_sublength + ishfr], GmR);
}
}
template <typename DataT>
void cubScan(const DataT *d_in, DataT *d_out, int num_items,
cudaStream_t stream) {
auto &allocator = *c10::cuda::CUDACachingAllocator::get();
void cubScan(const DataT* d_in, DataT* d_out, int num_items, cudaStream_t stream) {
auto& allocator = *c10::cuda::CUDACachingAllocator::get();
// Determine temporary device storage requirements
void *d_temp_storage = NULL;
void* d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out,
num_items, stream);
cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
// Allocate temporary storage
auto buffer_tmp = allocator.allocate(temp_storage_bytes);
d_temp_storage = buffer_tmp.get();
// Run exclusive prefix sum
cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out,
num_items, stream);
cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
}
template <typename DataT, typename IndexT>
int cubEncode(const DataT *d_in, DataT *d_unique_out, IndexT *d_counts_out,
int num_items, int *d_num_runs_out, cudaStream_t stream) {
auto &allocator = *c10::cuda::CUDACachingAllocator::get();
int cubEncode(
const DataT* d_in,
DataT* d_unique_out,
IndexT* d_counts_out,
int num_items,
int* d_num_runs_out,
cudaStream_t stream) {
auto& allocator = *c10::cuda::CUDACachingAllocator::get();
// Determine temporary device storage requirements
void *d_temp_storage = NULL;
void* d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
cub::DeviceRunLengthEncode::Encode(d_temp_storage, temp_storage_bytes, d_in,
d_unique_out, d_counts_out, d_num_runs_out,
num_items, stream);
cub::DeviceRunLengthEncode::Encode(
d_temp_storage,
temp_storage_bytes,
d_in,
d_unique_out,
d_counts_out,
d_num_runs_out,
num_items,
stream);
// Allocate temporary storage
auto buffer_tmp = allocator.allocate(temp_storage_bytes);
d_temp_storage = buffer_tmp.get();
// Run encoding
cub::DeviceRunLengthEncode::Encode(d_temp_storage, temp_storage_bytes, d_in,
d_unique_out, d_counts_out, d_num_runs_out,
num_items, stream);
cub::DeviceRunLengthEncode::Encode(
d_temp_storage,
temp_storage_bytes,
d_in,
d_unique_out,
d_counts_out,
d_num_runs_out,
num_items,
stream);
int num_selected = 0;
cudaMemcpyAsync(&num_selected, d_num_runs_out, sizeof(int), cudaMemcpyDefault,
stream);
cudaMemcpyAsync(&num_selected, d_num_runs_out, sizeof(int), cudaMemcpyDefault, stream);
cudaStreamSynchronize(stream);
return num_selected;
}
template <typename DataT, typename LambdaOpT>
int cubDeviceSelect(const DataT *d_in, DataT *d_out, int num_items,
int *d_num_selected_out, LambdaOpT select_op,
cudaStream_t stream) {
auto &allocator = *c10::cuda::CUDACachingAllocator::get();
int cubDeviceSelect(
const DataT* d_in,
DataT* d_out,
int num_items,
int* d_num_selected_out,
LambdaOpT select_op,
cudaStream_t stream) {
auto& allocator = *c10::cuda::CUDACachingAllocator::get();
// Determine temporary device storage requirements
void *d_temp_storage = NULL;
void* d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
cub::DeviceSelect::If(d_temp_storage, temp_storage_bytes, d_in, d_out,
d_num_selected_out, num_items, select_op);
cub::DeviceSelect::If(
d_temp_storage, temp_storage_bytes, d_in, d_out, d_num_selected_out, num_items, select_op);
// Allocate temporary storage
auto buffer_tmp = allocator.allocate(temp_storage_bytes);
d_temp_storage = buffer_tmp.get();
// Run selection
cub::DeviceSelect::If(d_temp_storage, temp_storage_bytes, d_in, d_out,
d_num_selected_out, num_items, select_op, stream);
cub::DeviceSelect::If(
d_temp_storage,
temp_storage_bytes,
d_in,
d_out,
d_num_selected_out,
num_items,
select_op,
stream);
int num_selected = 0;
cudaMemcpyAsync(&num_selected, d_num_selected_out, sizeof(int),
cudaMemcpyDefault, stream);
cudaMemcpyAsync(&num_selected, d_num_selected_out, sizeof(int), cudaMemcpyDefault, stream);
cudaStreamSynchronize(stream);
return num_selected;
}
template <typename DataT>
DataT cubMax(const DataT *d_in, int num_items, DataT *d_out,
cudaStream_t stream) {
auto &allocator = *c10::cuda::CUDACachingAllocator::get();
DataT cubMax(const DataT* d_in, int num_items, DataT* d_out, cudaStream_t stream) {
auto& allocator = *c10::cuda::CUDACachingAllocator::get();
// Determine temporary device storage requirements
void *d_temp_storage = NULL;
void* d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
cub::DeviceReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_out,
num_items, stream);
cub::DeviceReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
// Allocate temporary storage
auto buffer_tmp = allocator.allocate(temp_storage_bytes);
d_temp_storage = buffer_tmp.get();
// Run max-reduction
cub::DeviceReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_out,
num_items, stream);
cub::DeviceReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
int maxVal = 0;
cudaMemcpyAsync(&maxVal, d_out, sizeof(DataT), cudaMemcpyDefault, stream);
@@ -396,42 +409,52 @@ DataT cubMax(const DataT *d_in, int num_items, DataT *d_out,
return maxVal;
}
void initConsts(AEVScalarParams<float> &aev_params, cudaStream_t stream) {
void initConsts(AEVScalarParams<float>& aev_params, cudaStream_t stream) {
int num_species = aev_params.num_species;
assert(num_species <= MAX_NSPECIES);
// precompute the aev offsets and load to constant memory
int *subaev_offsets = new int[num_species * num_species];
int* subaev_offsets = new int[num_species * num_species];
for (int t = 0; t < num_species; ++t) {
int offset = 0;
for (int s = 0; s < num_species; s++) {
if (t < num_species - s) {
subaev_offsets[s * num_species + s + t] =
aev_params.angular_sublength * (offset + t);
subaev_offsets[(s + t) * num_species + s] =
aev_params.angular_sublength * (offset + t);
subaev_offsets[s * num_species + s + t] = aev_params.angular_sublength * (offset + t);
subaev_offsets[(s + t) * num_species + s] = aev_params.angular_sublength * (offset + t);
}
offset += num_species - s;
}
}
cudaMemcpyToSymbolAsync(csubaev_offsets, subaev_offsets,
sizeof(int) * num_species * num_species, 0,
cudaMemcpyDefault, stream);
cudaMemcpyToSymbolAsync(
csubaev_offsets,
subaev_offsets,
sizeof(int) * num_species * num_species,
0,
cudaMemcpyDefault,
stream);
delete[] subaev_offsets;
}
// NOTE: assumes size of EtaA_t = Zeta_t = EtaR_t = 1
template <typename ScalarRealT = float>
torch::Tensor cuComputeAEV(torch::Tensor coordinates_t, torch::Tensor species_t,
double Rcr_, double Rca_, torch::Tensor EtaR_t,
torch::Tensor ShfR_t, torch::Tensor EtaA_t,
torch::Tensor Zeta_t, torch::Tensor ShfA_t,
torch::Tensor ShfZ_t, int64_t num_species_) {
TORCH_CHECK((species_t.dtype() == torch::kInt32) &&
(coordinates_t.dtype() == torch::kFloat32),
"Unsupported input type");
TORCH_CHECK(EtaR_t.size(0) == 1 || EtaA_t.size(0) == 1 || Zeta_t.size(0) == 1,
"cuda extension is currently not supported for the specified "
"configuration");
torch::Tensor cuComputeAEV(
torch::Tensor coordinates_t,
torch::Tensor species_t,
double Rcr_,
double Rca_,
torch::Tensor EtaR_t,
torch::Tensor ShfR_t,
torch::Tensor EtaA_t,
torch::Tensor Zeta_t,
torch::Tensor ShfA_t,
torch::Tensor ShfZ_t,
int64_t num_species_) {
TORCH_CHECK(
(species_t.dtype() == torch::kInt32) && (coordinates_t.dtype() == torch::kFloat32),
"Unsupported input type");
TORCH_CHECK(
EtaR_t.size(0) == 1 || EtaA_t.size(0) == 1 || Zeta_t.size(0) == 1,
"cuda extension is currently not supported for the specified "
"configuration");
ScalarRealT Rcr = Rcr_;
ScalarRealT Rca = Rca_;
@@ -448,35 +471,29 @@ torch::Tensor cuComputeAEV(torch::Tensor coordinates_t, torch::Tensor species_t,
aev_params.radial_sublength = EtaR_t.size(0) * ShfR_t.size(0);
aev_params.radial_length = aev_params.radial_sublength * num_species;
aev_params.angular_sublength =
EtaA_t.size(0) * Zeta_t.size(0) * ShfA_t.size(0) * ShfZ_t.size(0);
aev_params.angular_length =
aev_params.angular_sublength * (num_species * (num_species + 1) / 2);
aev_params.angular_sublength = EtaA_t.size(0) * Zeta_t.size(0) * ShfA_t.size(0) * ShfZ_t.size(0);
aev_params.angular_length = aev_params.angular_sublength * (num_species * (num_species + 1) / 2);
int aev_length = aev_params.radial_length + aev_params.angular_length;
auto aev_t = torch::zeros({n_molecules, max_natoms_per_mol, aev_length},
coordinates_t.options());
auto aev_t = torch::zeros({n_molecules, max_natoms_per_mol, aev_length}, coordinates_t.options());
if (species_t.numel() == 0) {
return aev_t;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto thrust_allocator =
THCThrustAllocator(at::globalContext().lazyInitCUDA());
auto thrust_allocator = THCThrustAllocator(at::globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(thrust_allocator).on(stream);
auto &allocator = *c10::cuda::CUDACachingAllocator::get();
auto& allocator = *c10::cuda::CUDACachingAllocator::get();
// precompute the aev offsets and load to constant memory
initConsts(aev_params, stream);
// buffer to store all the pairwise distances (Rij)
auto total_natom_pairs =
n_molecules * max_natoms_per_mol * max_natoms_per_mol;
auto buffer_Rij =
allocator.allocate(sizeof(PairDist<float>) * total_natom_pairs);
PairDist<float> *d_Rij = (PairDist<float> *)buffer_Rij.get();
auto total_natom_pairs = n_molecules * max_natoms_per_mol * max_natoms_per_mol;
auto buffer_Rij = allocator.allocate(sizeof(PairDist<float>) * total_natom_pairs);
PairDist<float>* d_Rij = (PairDist<float>*)buffer_Rij.get();
// init all Rij to inf
PairDist<float> init;
@@ -485,28 +502,31 @@ torch::Tensor cuComputeAEV(torch::Tensor coordinates_t, torch::Tensor species_t,
// buffer to store all the pairwise distances that are needed for Radial AEV
// computation
auto buffer_radialRij =
allocator.allocate(sizeof(PairDist<float>) * total_natom_pairs);
PairDist<float> *d_radialRij = (PairDist<float> *)buffer_radialRij.get();
auto buffer_radialRij = allocator.allocate(sizeof(PairDist<float>) * total_natom_pairs);
PairDist<float>* d_radialRij = (PairDist<float>*)buffer_radialRij.get();
auto buffer_count = allocator.allocate(sizeof(int));
int *d_count_out = (int *)buffer_count.get();
int* d_count_out = (int*)buffer_count.get();
const int block_size = 64;
dim3 block(8, 8, 1);
// Compute pairwise distance (Rij) for all atom pairs in a molecule
pairwiseDistance<<<n_molecules, block, sizeof(float) * max_natoms_per_mol * 3,
stream>>>(
pairwiseDistance<<<n_molecules, block, sizeof(float) * max_natoms_per_mol * 3, stream>>>(
species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
d_Rij, max_natoms_per_mol);
d_Rij,
max_natoms_per_mol);
// Extract Rijs that are needed for RadialAEV computation, i.e. all the Rij <=
// Rcr
int nRadialRij = cubDeviceSelect(
d_Rij, d_radialRij, total_natom_pairs, d_count_out,
[=] __device__(const PairDist<float> d) { return d.Rij <= Rcr; }, stream);
d_Rij,
d_radialRij,
total_natom_pairs,
d_count_out,
[=] __device__(const PairDist<float> d) { return d.Rij <= Rcr; },
stream);
int nblocks = (nRadialRij * 8 + block_size - 1) / block_size;
cuRadialAEVs<int, float, 8><<<nblocks, block_size, 0, stream>>>(
@@ -514,40 +534,41 @@ torch::Tensor cuComputeAEV(torch::Tensor coordinates_t, torch::Tensor species_t,
ShfR_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
EtaR_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
aev_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
d_radialRij, aev_params, nRadialRij);
d_radialRij,
aev_params,
nRadialRij);
// reuse buffer allocated for all Rij
// d_angularRij will store all the Rij required in Angular AEV computation
PairDist<float> *d_angularRij = d_Rij;
PairDist<float>* d_angularRij = d_Rij;
// Extract Rijs that are needed for AngularAEV computation, i.e. all the Rij
// <= Rca
int nAngularRij = cubDeviceSelect(
d_radialRij, d_angularRij, nRadialRij, d_count_out,
[=] __device__(const PairDist<float> d) { return d.Rij <= Rca; }, stream);
d_radialRij,
d_angularRij,
nRadialRij,
d_count_out,
[=] __device__(const PairDist<float> d) { return d.Rij <= Rca; },
stream);
auto buffer_centralAtom =
allocator.allocate(sizeof(PairDist<float>) * nAngularRij);
PairDist<float> *d_centralAtom = (PairDist<float> *)buffer_centralAtom.get();
auto buffer_centralAtom = allocator.allocate(sizeof(PairDist<float>) * nAngularRij);
PairDist<float>* d_centralAtom = (PairDist<float>*)buffer_centralAtom.get();
auto buffer_numPairsPerCenterAtom =
allocator.allocate(sizeof(int) * nAngularRij);
int *d_numPairsPerCenterAtom = (int *)buffer_numPairsPerCenterAtom.get();
auto buffer_numPairsPerCenterAtom = allocator.allocate(sizeof(int) * nAngularRij);
int* d_numPairsPerCenterAtom = (int*)buffer_numPairsPerCenterAtom.get();
// group by center atom
int ncenter_atoms =
cubEncode(d_angularRij, d_centralAtom, d_numPairsPerCenterAtom,
nAngularRij, d_count_out, stream);
int ncenter_atoms = cubEncode(
d_angularRij, d_centralAtom, d_numPairsPerCenterAtom, nAngularRij, d_count_out, stream);
auto buffer_centerAtomStartIdx =
allocator.allocate(sizeof(int) * ncenter_atoms);
int *d_centerAtomStartIdx = (int *)buffer_centerAtomStartIdx.get();
auto buffer_centerAtomStartIdx = allocator.allocate(sizeof(int) * ncenter_atoms);
int* d_centerAtomStartIdx = (int*)buffer_centerAtomStartIdx.get();
cubScan(d_numPairsPerCenterAtom, d_centerAtomStartIdx, ncenter_atoms, stream);
{
const int nthreads_per_catom = 32;
const int nblocks_angAEV =
(ncenter_atoms * nthreads_per_catom + block_size - 1) / block_size;
const int nblocks_angAEV = (ncenter_atoms * nthreads_per_catom + block_size - 1) / block_size;
auto smem_size = [&aev_params](int max_nbrs, int ncatom_per_tpb) {
int sm_aev = sizeof(float) * align<4>(aev_params.angular_length);
int sxyz = sizeof(float) * max_nbrs * 3;
@@ -558,15 +579,15 @@ torch::Tensor cuComputeAEV(torch::Tensor coordinates_t, torch::Tensor species_t,
return (sm_aev + sxyz + sRij + sfc + sj) * ncatom_per_tpb;
};
int maxNbrsPerCenterAtom =
cubMax(d_numPairsPerCenterAtom, ncenter_atoms, d_count_out, stream);
int maxNbrsPerCenterAtom = cubMax(d_numPairsPerCenterAtom, ncenter_atoms, d_count_out, stream);
int maxnbrs_per_atom_aligned = align<4>(maxNbrsPerCenterAtom);
cuAngularAEVs<<<nblocks_angAEV, block_size,
smem_size(maxnbrs_per_atom_aligned,
block_size / nthreads_per_catom),
stream>>>(
cuAngularAEVs<<<
nblocks_angAEV,
block_size,
smem_size(maxnbrs_per_atom_aligned, block_size / nthreads_per_catom),
stream>>>(
species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
ShfA_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
@@ -574,13 +595,20 @@ torch::Tensor cuComputeAEV(torch::Tensor coordinates_t, torch::Tensor species_t,
EtaA_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
Zeta_t.packed_accessor32<float, 1, torch::RestrictPtrTraits>(),
aev_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
d_angularRij, d_centralAtom, d_numPairsPerCenterAtom,
d_centerAtomStartIdx, aev_params, maxnbrs_per_atom_aligned,
align<4>(aev_params.angular_length), ncenter_atoms);
d_angularRij,
d_centralAtom,
d_numPairsPerCenterAtom,
d_centerAtomStartIdx,
aev_params,
maxnbrs_per_atom_aligned,
align<4>(aev_params.angular_length),
ncenter_atoms);
}
return aev_t;
}
TORCH_LIBRARY(cuaev, m) { m.def("cuComputeAEV", &cuComputeAEV<float>); }
TORCH_LIBRARY(cuaev, m) {
m.def("cuComputeAEV", &cuComputeAEV<float>);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}