Unverified commit 6a3dd807, authored by Jinze Xue, committed by GitHub
Browse files

Fix CUAEV being slower than pyaev when the molecule size increases to 1000 atoms or more (#565)



* fix

* format

* Update torchani/cuaev/aev.cu
Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com>
parent 910cca89
...@@ -127,6 +127,48 @@ __global__ void pairwiseDistance( ...@@ -127,6 +127,48 @@ __global__ void pairwiseDistance(
} }
} }
// Pairwise distances for the single-molecule case (only molecule 0 is read
// and written).
//
// Launch layout: 2D grid of 2D blocks. A thread's global x index selects
// atom i and its global y index selects atom j; threads that fall outside the
// max_natoms_per_mol x max_natoms_per_mol square exit immediately. The kernel
// uses no shared memory.
//
// For every ordered pair (i, j) with i != j whose species are both different
// from -1 (presumably the padding marker -- confirm against the caller), the
// record {Rij, midx = 0, i, j} is stored at d_Rij[i * max_natoms_per_mol + j].
// Every other slot of d_Rij is left untouched.
//
// species_t           per-atom species, indexed [molecule][atom]
// pos_t               per-atom coordinates, indexed [molecule][atom][x|y|z]
// d_Rij               output buffer of pair records
// max_natoms_per_mol  number of atom slots per molecule (row stride of d_Rij)
template <typename SpeciesT, typename DataT, typename IndexT = int>
__global__ void pairwiseDistanceSingleMolecule(
    torch::PackedTensorAccessor32<SpeciesT, 2, torch::RestrictPtrTraits> species_t,
    torch::PackedTensorAccessor32<DataT, 3, torch::RestrictPtrTraits> pos_t,
    PairDist<DataT>* d_Rij,
    IndexT max_natoms_per_mol) {
  // This kernel only ever addresses molecule 0, so its slab of d_Rij starts
  // at offset 0 and no per-molecule base offset is needed.
  constexpr int mol_idx = 0;

  const int atom_i = blockIdx.x * blockDim.x + threadIdx.x;
  const int atom_j = blockIdx.y * blockDim.y + threadIdx.y;

  // The grid is rounded up to whole tiles; drop the overhanging threads.
  const bool inside = atom_i < max_natoms_per_mol && atom_j < max_natoms_per_mol;
  if (!inside) {
    return;
  }

  // Skip self-pairs and any pair that involves a species of -1.
  const SpeciesT species_i = species_t[mol_idx][atom_i];
  const SpeciesT species_j = species_t[mol_idx][atom_j];
  if (species_i == -1 || species_j == -1 || atom_i == atom_j) {
    return;
  }

  // Displacement from atom i to atom j.
  const DataT dx = pos_t[mol_idx][atom_j][0] - pos_t[mol_idx][atom_i][0];
  const DataT dy = pos_t[mol_idx][atom_j][1] - pos_t[mol_idx][atom_i][1];
  const DataT dz = pos_t[mol_idx][atom_j][2] - pos_t[mol_idx][atom_i][2];
  const DataT dist_sq = dx * dx + dy * dy + dz * dz;

  // Build the record locally, then store it as one struct assignment.
  PairDist<DataT> entry;
  entry.Rij = sqrt(dist_sq);
  entry.midx = mol_idx;
  entry.i = atom_i;
  entry.j = atom_j;
  d_Rij[atom_i * max_natoms_per_mol + atom_j] = entry;
}
// every block compute blocksize RIJ's gradient by column major, to avoid atomicAdd waiting // every block compute blocksize RIJ's gradient by column major, to avoid atomicAdd waiting
template <typename DataT, typename IndexT = int> template <typename DataT, typename IndexT = int>
__global__ void pairwiseDistance_backward( __global__ void pairwiseDistance_backward(
...@@ -857,13 +899,27 @@ Result cuaev_forward( ...@@ -857,13 +899,27 @@ Result cuaev_forward(
const int block_size = 64; const int block_size = 64;
dim3 block(8, 8, 1); dim3 block(8, 8, 1);
// Compute pairwise distance (Rij) for all atom pairs in a molecule if (n_molecules == 1) {
// maximum 4096 atoms, which needs 49152 byte (48 kb) of shared memory int tileWidth = 32;
pairwiseDistance<<<n_molecules, block, sizeof(float) * max_natoms_per_mol * 3, stream>>>( int tilesPerRow = (max_natoms_per_mol + tileWidth - 1) / tileWidth;
species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(), dim3 block(tileWidth, tileWidth, 1);
coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(), dim3 grid(tilesPerRow, tilesPerRow, 1);
d_Rij, pairwiseDistanceSingleMolecule<<<grid, block, 0, stream>>>(
max_natoms_per_mol); species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
d_Rij,
max_natoms_per_mol);
} else {
dim3 block(8, 8, 1);
// Compute pairwise distance (Rij) for all atom pairs in a molecule
// maximum 4096 atoms, which needs 49152 byte (48 kb) of shared memory
// TODO: the kernel is not optimized for batched huge molecule (max_natoms_per_mol > 1000)
pairwiseDistance<<<n_molecules, block, sizeof(float) * max_natoms_per_mol * 3, stream>>>(
species_t.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
d_Rij,
max_natoms_per_mol);
}
// Extract the Rijs that are needed for RadialAEV computation, i.e. all the Rij <= Rcr // Extract the Rijs that are needed for RadialAEV computation, i.e. all the Rij <= Rcr
int nRadialRij = cubDeviceSelect( int nRadialRij = cubDeviceSelect(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment