Unverified Commit bf771af0 authored by Jinze (Richard) Xue's avatar Jinze (Richard) Xue Committed by GitHub
Browse files

Fix AEV and CUAEV Multi-GPU Device Bug (#597)

* Fix AEV and CUAEV GPU Device Bug

* format

* apply review suggestion

* apply review suggestion
parent ef834586
......@@ -8,8 +8,8 @@ from torchani.testing import TestCase, make_tensor
path = os.path.dirname(os.path.realpath(__file__))
skipIfNoGPU = unittest.skipIf(not torch.cuda.is_available(),
'There is no device to run this test')
skipIfNoGPU = unittest.skipIf(not torch.cuda.is_available(), 'There is no device to run this test')
skipIfNoMultiGPU = unittest.skipIf(not torch.cuda.device_count() >= 2, 'There is not enough GPU devices to run this test')
skipIfNoCUAEV = unittest.skipIf(not torchani.aev.has_cuaev, "only valid when cuaev is installed")
......@@ -39,9 +39,9 @@ class TestCUAEVNoGPU(TestCase):
@skipIfNoCUAEV
class TestCUAEV(TestCase):
def setUp(self):
def setUp(self, device='cuda:0'):
self.tolerance = 5e-5
self.device = 'cuda'
self.device = device
Rcr = 5.2000e+00
Rca = 3.5000e+00
EtaR = torch.tensor([1.6000000e+01], device=self.device)
......@@ -131,6 +131,15 @@ class TestCUAEV(TestCase):
_, cu_aev = self.cuaev_computer((species, coordinates))
self.assertEqual(cu_aev, aev)
@skipIfNoMultiGPU
def testMultiGPU(self):
self.setUp(device='cuda:1')
self.testSimple()
self.testSimpleBackward()
self.testSimpleDoubleBackward_1()
self.testSimpleDoubleBackward_2()
self.setUp(device='cuda:0')
def testSimpleBackward(self):
coordinates = torch.tensor([
[[0.03192167, 0.00638559, 0.01301679],
......
#include <aev.h>
#include <torch/extension.h>
#include <cuaev_cub.cuh>
#include <vector>
#include <ATen/Context.h>
#include <THC/THC.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <THC/THCThrustAllocator.cuh>
#include <vector>
#define PI 3.141592653589793
using torch::Tensor;
......@@ -742,9 +745,13 @@ Result cuaev_forward(const Tensor& coordinates_t, const Tensor& species_t, const
TORCH_CHECK(
(species_t.dtype() == torch::kInt32) && (coordinates_t.dtype() == torch::kFloat32), "Unsupported input type");
TORCH_CHECK(
aev_params.EtaR_t.size(0) == 1 || aev_params.EtaA_t.size(0) == 1 || aev_params.Zeta_t.size(0) == 1,
aev_params.EtaR_t.size(0) == 1 && aev_params.EtaA_t.size(0) == 1 && aev_params.Zeta_t.size(0) == 1,
"cuda extension is currently not supported for the specified "
"configuration");
TORCH_CHECK(
coordinates_t.device() == species_t.device() && coordinates_t.device() == aev_params.EtaR_t.device() &&
coordinates_t.device() == aev_params.EtaA_t.device(),
"coordinates, species, and aev_params should be on the same device");
float Rcr = aev_params.Rcr;
float Rca = aev_params.Rca;
......@@ -759,7 +766,8 @@ Result cuaev_forward(const Tensor& coordinates_t, const Tensor& species_t, const
aev_t, Tensor(), Tensor(), Tensor(), 0, 0, 0, Tensor(), Tensor(), Tensor(), 0, 0, 0, coordinates_t, species_t};
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
at::cuda::CUDAGuard device_guard(coordinates_t.device().index());
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
auto& allocator = *c10::cuda::CUDACachingAllocator::get();
// buffer to store all the pairwise distance (Rij)
......@@ -790,6 +798,7 @@ Result cuaev_forward(const Tensor& coordinates_t, const Tensor& species_t, const
coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
d_Rij,
max_natoms_per_mol);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
dim3 block(8, 8, 1);
// Compute pairwise distance (Rij) for all atom pairs in a molecule
......@@ -800,6 +809,7 @@ Result cuaev_forward(const Tensor& coordinates_t, const Tensor& species_t, const
coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
d_Rij,
max_natoms_per_mol);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Extract Rijs that is needed for RadialAEV comptuation i.e. all the Rij <= Rcr
......@@ -822,6 +832,7 @@ Result cuaev_forward(const Tensor& coordinates_t, const Tensor& species_t, const
aev_params.radial_length,
aev_params.radial_sublength,
nRadialRij);
C10_CUDA_KERNEL_LAUNCH_CHECK();
// reuse buffer allocated for all Rij
// d_angularRij will store all the Rij required in Angular AEV computation
......@@ -890,6 +901,7 @@ Result cuaev_forward(const Tensor& coordinates_t, const Tensor& species_t, const
maxnbrs_per_atom_aligned,
angular_length_aligned,
ncenter_atoms);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return {
aev_t,
......@@ -917,7 +929,8 @@ Tensor cuaev_backward(const Tensor& grad_output, const AEVScalarParams& aev_para
const int n_molecules = coordinates_t.size(0);
const int max_natoms_per_mol = coordinates_t.size(1);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
at::cuda::CUDAGuard device_guard(coordinates_t.device().index());
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
auto grad_coord = torch::zeros(coordinates_t.sizes(), coordinates_t.options().requires_grad(false)); // [2, 5, 3]
......@@ -943,6 +956,7 @@ Tensor cuaev_backward(const Tensor& grad_output, const AEVScalarParams& aev_para
aev_params.radial_length,
aev_params.radial_sublength,
result.nRadialRij);
C10_CUDA_KERNEL_LAUNCH_CHECK();
// For best result, block_size should match average molecule size (no padding) to avoid atomicAdd
nblocks = (result.nRadialRij + block_size - 1) / block_size;
......@@ -952,6 +966,7 @@ Tensor cuaev_backward(const Tensor& grad_output, const AEVScalarParams& aev_para
grad_coord.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
d_radialRij,
result.nRadialRij);
C10_CUDA_KERNEL_LAUNCH_CHECK();
auto smem_size = [&aev_params](int max_nbrs, int ncatom_per_tpb) {
int sxyz = sizeof(float) * max_nbrs * 3;
......@@ -991,6 +1006,7 @@ Tensor cuaev_backward(const Tensor& grad_output, const AEVScalarParams& aev_para
result.maxnbrs_per_atom_aligned,
result.angular_length_aligned,
result.ncenter_atoms);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return grad_coord;
}
......@@ -1002,7 +1018,8 @@ Tensor cuaev_double_backward(const Tensor& grad_force, const AEVScalarParams& ae
const int n_molecules = coordinates_t.size(0);
const int max_natoms_per_mol = coordinates_t.size(1);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
at::cuda::CUDAGuard device_guard(coordinates_t.device().index());
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
int aev_length = aev_params.radial_length + aev_params.angular_length;
......@@ -1027,6 +1044,7 @@ Tensor cuaev_double_backward(const Tensor& grad_force, const AEVScalarParams& ae
grad_force.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
d_radialRij,
result.nRadialRij);
C10_CUDA_KERNEL_LAUNCH_CHECK();
nblocks = (result.nRadialRij * 8 + block_size - 1) / block_size;
cuRadialAEVs_backward_or_doublebackward<true, int, float, 8><<<nblocks, block_size, 0, stream>>>(
......@@ -1040,6 +1058,7 @@ Tensor cuaev_double_backward(const Tensor& grad_force, const AEVScalarParams& ae
aev_params.radial_length,
aev_params.radial_sublength,
result.nRadialRij);
C10_CUDA_KERNEL_LAUNCH_CHECK();
auto smem_size = [&aev_params](int max_nbrs, int ncatom_per_tpb) {
int sxyz = sizeof(float) * max_nbrs * 3;
......@@ -1078,6 +1097,7 @@ Tensor cuaev_double_backward(const Tensor& grad_force, const AEVScalarParams& ae
result.maxnbrs_per_atom_aligned,
result.angular_length_aligned,
result.ncenter_atoms);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return grad_grad_aev;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment