Unverified Commit 33fe85d9 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Remove constant memory usage (#572)



* save

* Remove constant memory usage

* Update torchani/cuaev/aev.cu
Co-authored-by: default avatarJinze Xue <yueyericardo@gmail.com>

* Update torchani/cuaev/aev.cu
Co-authored-by: default avatarJinze Xue <yueyericardo@gmail.com>

* Update torchani/cuaev/aev.cu
Co-authored-by: default avatarJinze Xue <yueyericardo@gmail.com>

* format
Co-authored-by: default avatarJinze Xue <yueyericardo@gmail.com>
parent 813f6e61
......@@ -144,8 +144,19 @@ struct AEVScalarParams {
}
};
#define MAX_NSPECIES 10
__constant__ int csubaev_offsets[MAX_NSPECIES * MAX_NSPECIES];
// fetch from the following matrix
// [[ 0, 1, 2, 3, 4],
// [ 1, 5, 6, 7, 8],
// [ 2, 6, 9, 10, 11],
// [ 3, 7, 10, 12, 13],
// [ 4, 8, 11, 13, 14]]
constexpr int csubaev_offsets(int i, int j, int n) {
int larger = std::max(i, j);
int smaller = std::min(i, j);
int starting = smaller * (2 * n - smaller + 1) / 2; // n + (n - 1) + ... + (n - smaller + 1)
int offset = larger - smaller;
return starting + offset;
}
template <typename DataT>
struct PairDist {
......@@ -438,7 +449,7 @@ __global__ void cuAngularAEVs(
DataT Rijk = (Rij + Rik) / 2;
DataT fc_ijk = fc_ij * fc_ik;
IndexT subaev_offset = csubaev_offsets[type_j * num_species + type_k];
IndexT subaev_offset = aev_params.angular_sublength * csubaev_offsets(type_j, type_k, num_species);
for (int itheta = tile.x; itheta < nShfZ; itheta += TILEX) {
DataT ShfZ = ShfZ_t[itheta];
......@@ -639,7 +650,7 @@ __global__ void cuAngularAEVs_backward_or_doublebackward(
DataT Rijk = (Rij + Rik) / 2;
DataT fc_ijk = fc_ij * fc_ik;
IndexT subaev_offset = csubaev_offsets[type_j * num_species + type_k];
IndexT subaev_offset = aev_params.angular_sublength * csubaev_offsets(type_j, type_k, num_species);
for (int itheta = tile.x; itheta < nShfZ; itheta += TILEX) {
DataT ShfZ = ShfZ_t[itheta];
......@@ -941,26 +952,6 @@ DataT cubMax(const DataT* d_in, int num_items, DataT* d_out, cudaStream_t stream
return maxVal;
}
void initConsts(AEVScalarParams<float>& aev_params, cudaStream_t stream) {
int num_species = aev_params.num_species;
assert(num_species <= MAX_NSPECIES);
// precompute the aev offsets and load to constand memory
int* subaev_offsets = new int[num_species * num_species];
for (int t = 0; t < num_species; ++t) {
int offset = 0;
for (int s = 0; s < num_species; s++) {
if (t < num_species - s) {
subaev_offsets[s * num_species + s + t] = aev_params.angular_sublength * (offset + t);
subaev_offsets[(s + t) * num_species + s] = aev_params.angular_sublength * (offset + t);
}
offset += num_species - s;
}
}
cudaMemcpyToSymbolAsync(
csubaev_offsets, subaev_offsets, sizeof(int) * num_species * num_species, 0, cudaMemcpyDefault, stream);
delete[] subaev_offsets;
}
struct Result {
Tensor aev_t;
AEVScalarParams<float> aev_params;
......@@ -1030,9 +1021,6 @@ Result cuaev_forward(
auto policy = thrust::cuda::par(thrust_allocator).on(stream);
auto& allocator = *c10::cuda::CUDACachingAllocator::get();
// precompute the aev offsets and load to constand memory
initConsts(aev_params, stream);
// buffer to store all the pairwise distance (Rij)
auto total_natom_pairs = n_molecules * max_natoms_per_mol * max_natoms_per_mol;
auto d_options = torch::dtype(torch::kUInt8).device(coordinates_t.device());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment