Commit 33fe85d9 (unverified), authored by Gao, Xiang, committed by GitHub
Browse files

Remove constant memory usage (#572)



* save

* Remove constant memory usage

* Update torchani/cuaev/aev.cu
Co-authored-by: Jinze Xue <yueyericardo@gmail.com>

* Update torchani/cuaev/aev.cu
Co-authored-by: Jinze Xue <yueyericardo@gmail.com>

* Update torchani/cuaev/aev.cu
Co-authored-by: Jinze Xue <yueyericardo@gmail.com>

* format
Co-authored-by: Jinze Xue <yueyericardo@gmail.com>
parent 813f6e61
...@@ -144,8 +144,19 @@ struct AEVScalarParams { ...@@ -144,8 +144,19 @@ struct AEVScalarParams {
} }
}; };
#define MAX_NSPECIES 10 // fetch from the following matrix
__constant__ int csubaev_offsets[MAX_NSPECIES * MAX_NSPECIES]; // [[ 0, 1, 2, 3, 4],
// [ 1, 5, 6, 7, 8],
// [ 2, 6, 9, 10, 11],
// [ 3, 7, 10, 12, 13],
// [ 4, 8, 11, 13, 14]]
// Returns the position of the unordered pair (i, j) in the row-major packed
// upper triangle (diagonal included) of an n x n symmetric matrix.
// The result is symmetric in i and j. For n = 5, for example,
// (0, 0) -> 0, (1, 1) -> 5, (1, 4) -> 8, (4, 1) -> 8 and (4, 4) -> 14.
constexpr int csubaev_offsets(int i, int j, int n) {
  const int lo = (j < i) ? j : i;
  const int hi = (i < j) ? j : i;
  // Rows 0 .. lo-1 of the triangle hold n + (n - 1) + ... + (n - lo + 1) entries.
  const int row_start = lo * (2 * n - lo + 1) / 2;
  // Within row `lo` the entries start at the diagonal, so the column is counted from `lo`.
  return row_start + (hi - lo);
}
template <typename DataT> template <typename DataT>
struct PairDist { struct PairDist {
...@@ -438,7 +449,7 @@ __global__ void cuAngularAEVs( ...@@ -438,7 +449,7 @@ __global__ void cuAngularAEVs(
DataT Rijk = (Rij + Rik) / 2; DataT Rijk = (Rij + Rik) / 2;
DataT fc_ijk = fc_ij * fc_ik; DataT fc_ijk = fc_ij * fc_ik;
IndexT subaev_offset = csubaev_offsets[type_j * num_species + type_k]; IndexT subaev_offset = aev_params.angular_sublength * csubaev_offsets(type_j, type_k, num_species);
for (int itheta = tile.x; itheta < nShfZ; itheta += TILEX) { for (int itheta = tile.x; itheta < nShfZ; itheta += TILEX) {
DataT ShfZ = ShfZ_t[itheta]; DataT ShfZ = ShfZ_t[itheta];
...@@ -639,7 +650,7 @@ __global__ void cuAngularAEVs_backward_or_doublebackward( ...@@ -639,7 +650,7 @@ __global__ void cuAngularAEVs_backward_or_doublebackward(
DataT Rijk = (Rij + Rik) / 2; DataT Rijk = (Rij + Rik) / 2;
DataT fc_ijk = fc_ij * fc_ik; DataT fc_ijk = fc_ij * fc_ik;
IndexT subaev_offset = csubaev_offsets[type_j * num_species + type_k]; IndexT subaev_offset = aev_params.angular_sublength * csubaev_offsets(type_j, type_k, num_species);
for (int itheta = tile.x; itheta < nShfZ; itheta += TILEX) { for (int itheta = tile.x; itheta < nShfZ; itheta += TILEX) {
DataT ShfZ = ShfZ_t[itheta]; DataT ShfZ = ShfZ_t[itheta];
...@@ -941,26 +952,6 @@ DataT cubMax(const DataT* d_in, int num_items, DataT* d_out, cudaStream_t stream ...@@ -941,26 +952,6 @@ DataT cubMax(const DataT* d_in, int num_items, DataT* d_out, cudaStream_t stream
return maxVal; return maxVal;
} }
// Builds the num_species x num_species table of angular sub-AEV offsets on the
// host and copies it into the `csubaev_offsets` constant-memory symbol on `stream`.
//
// The table is symmetric: entry (a, b) equals entry (b, a). Each entry is
// angular_sublength times the position of the pair in the row-major packed
// upper triangle (diagonal included) of the species-pair matrix.
void initConsts(AEVScalarParams<float>& aev_params, cudaStream_t stream) {
  const int n = aev_params.num_species;
  // The constant-memory table is declared with MAX_NSPECIES * MAX_NSPECIES entries.
  assert(n <= MAX_NSPECIES);

  // precompute the aev offsets and load to constant memory
  int* host_table = new int[n * n];
  int row_start = 0; // packed-triangle entries contained in the rows above `row`
  for (int row = 0; row < n; ++row) {
    for (int col = row; col < n; ++col) {
      const int value = aev_params.angular_sublength * (row_start + (col - row));
      // Mirror the value so the lookup does not depend on the order of the pair.
      host_table[row * n + col] = value;
      host_table[col * n + row] = value;
    }
    row_start += n - row;
  }

  // NOTE(review): the status returned by the copy is not inspected, and the host
  // buffer is released right after the copy is enqueued -- confirm both are acceptable.
  cudaMemcpyToSymbolAsync(csubaev_offsets, host_table, sizeof(int) * n * n, 0, cudaMemcpyDefault, stream);
  delete[] host_table;
}
struct Result { struct Result {
Tensor aev_t; Tensor aev_t;
AEVScalarParams<float> aev_params; AEVScalarParams<float> aev_params;
...@@ -1030,9 +1021,6 @@ Result cuaev_forward( ...@@ -1030,9 +1021,6 @@ Result cuaev_forward(
auto policy = thrust::cuda::par(thrust_allocator).on(stream); auto policy = thrust::cuda::par(thrust_allocator).on(stream);
auto& allocator = *c10::cuda::CUDACachingAllocator::get(); auto& allocator = *c10::cuda::CUDACachingAllocator::get();
// precompute the aev offsets and load to constant memory
initConsts(aev_params, stream);
// buffer to store all the pairwise distance (Rij) // buffer to store all the pairwise distance (Rij)
auto total_natom_pairs = n_molecules * max_natoms_per_mol * max_natoms_per_mol; auto total_natom_pairs = n_molecules * max_natoms_per_mol * max_natoms_per_mol;
auto d_options = torch::dtype(torch::kUInt8).device(coordinates_t.device()); auto d_options = torch::dtype(torch::kUInt8).device(coordinates_t.device());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment