Unverified Commit bf771af0 authored by Jinze (Richard) Xue, committed by GitHub

Fix AEV and CUAEV Multi-GPU Device Bug (#597)

* Fix AEV and CUAEV GPU Device Bug

* format

* apply review suggestion

* apply review suggestion
parent ef834586
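
The bug: on a multi-GPU machine the current CUDA device defaults to cuda:0, so at::cuda::getCurrentCUDAStream() returned a cuda:0 stream even when the input tensors lived on another GPU, and the cuaev kernels were launched on the wrong device. The fix guards on the input tensor's device before fetching the stream and adds a launch check after every kernel. Below is a minimal standalone sketch of that pattern, using a hypothetical scale_inplace op rather than the actual cuaev entry points:

// Minimal sketch (hypothetical op, not torchani's cuaev_forward) of the
// device-guard + launch-check pattern applied in this commit.
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAException.h>

__global__ void scale_kernel(float* data, float s, int64_t n) {
  int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (i < n)
    data[i] *= s;
}

void scale_inplace(at::Tensor& x, float s) {
  TORCH_CHECK(x.is_cuda() && x.is_contiguous() && x.scalar_type() == at::kFloat,
              "expected a contiguous float CUDA tensor");
  // Without this guard the current device is whatever the caller last set
  // (usually cuda:0), so getCurrentCUDAStream() would hand back a stream on
  // the wrong GPU when x lives on cuda:1.
  c10::cuda::CUDAGuard device_guard(x.device());
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream();
  const int64_t n = x.numel();
  if (n == 0)
    return;
  const int threads = 256;
  const int blocks = static_cast<int>((n + threads - 1) / threads);
  scale_kernel<<<blocks, threads, 0, stream>>>(x.data_ptr<float>(), s, n);
  C10_CUDA_KERNEL_LAUNCH_CHECK(); // surface launch errors at the call site
}
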
@@ -8,8 +8,8 @@ from torchani.testing import TestCase, make_tensor
 path = os.path.dirname(os.path.realpath(__file__))
-skipIfNoGPU = unittest.skipIf(not torch.cuda.is_available(),
-                              'There is no device to run this test')
+skipIfNoGPU = unittest.skipIf(not torch.cuda.is_available(), 'There is no device to run this test')
+skipIfNoMultiGPU = unittest.skipIf(not torch.cuda.device_count() >= 2, 'There is not enough GPU devices to run this test')
 skipIfNoCUAEV = unittest.skipIf(not torchani.aev.has_cuaev, "only valid when cuaev is installed")
@@ -39,9 +39,9 @@ class TestCUAEVNoGPU(TestCase):
 @skipIfNoCUAEV
 class TestCUAEV(TestCase):
-    def setUp(self):
+    def setUp(self, device='cuda:0'):
         self.tolerance = 5e-5
-        self.device = 'cuda'
+        self.device = device
         Rcr = 5.2000e+00
         Rca = 3.5000e+00
         EtaR = torch.tensor([1.6000000e+01], device=self.device)
@@ -131,6 +131,15 @@ class TestCUAEV(TestCase):
         _, cu_aev = self.cuaev_computer((species, coordinates))
         self.assertEqual(cu_aev, aev)

+    @skipIfNoMultiGPU
+    def testMultiGPU(self):
+        self.setUp(device='cuda:1')
+        self.testSimple()
+        self.testSimpleBackward()
+        self.testSimpleDoubleBackward_1()
+        self.testSimpleDoubleBackward_2()
+        self.setUp(device='cuda:0')
+
     def testSimpleBackward(self):
         coordinates = torch.tensor([
             [[0.03192167, 0.00638559, 0.01301679],
...
 #include <aev.h>
 #include <torch/extension.h>
 #include <cuaev_cub.cuh>
-#include <vector>
 #include <ATen/Context.h>
 #include <THC/THC.h>
 #include <c10/cuda/CUDACachingAllocator.h>
+#include <c10/cuda/CUDAException.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <c10/cuda/CUDAStream.h>
 #include <THC/THCThrustAllocator.cuh>
+#include <vector>

 #define PI 3.141592653589793
 using torch::Tensor;
@@ -742,9 +745,13 @@ Result cuaev_forward(const Tensor& coordinates_t, const Tensor& species_t, const
   TORCH_CHECK(
       (species_t.dtype() == torch::kInt32) && (coordinates_t.dtype() == torch::kFloat32), "Unsupported input type");
   TORCH_CHECK(
-      aev_params.EtaR_t.size(0) == 1 || aev_params.EtaA_t.size(0) == 1 || aev_params.Zeta_t.size(0) == 1,
+      aev_params.EtaR_t.size(0) == 1 && aev_params.EtaA_t.size(0) == 1 && aev_params.Zeta_t.size(0) == 1,
       "cuda extension is currently not supported for the specified "
       "configuration");
+  TORCH_CHECK(
+      coordinates_t.device() == species_t.device() && coordinates_t.device() == aev_params.EtaR_t.device() &&
+          coordinates_t.device() == aev_params.EtaA_t.device(),
+      "coordinates, species, and aev_params should be on the same device");

   float Rcr = aev_params.Rcr;
   float Rca = aev_params.Rca;
@@ -759,7 +766,8 @@ Result cuaev_forward(const Tensor& coordinates_t, const Tensor& species_t, const
         aev_t, Tensor(), Tensor(), Tensor(), 0, 0, 0, Tensor(), Tensor(), Tensor(), 0, 0, 0, coordinates_t, species_t};
   }

-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  at::cuda::CUDAGuard device_guard(coordinates_t.device().index());
+  at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
   auto& allocator = *c10::cuda::CUDACachingAllocator::get();

   // buffer to store all the pairwise distance (Rij)
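
For context, the at::cuda::CUDAGuard added above is an RAII device switch: it records the current device, makes the tensor's device current for the rest of the scope, and restores the previous device in its destructor. Roughly, in plain CUDA runtime calls (error handling and exception safety elided in this sketch):

// Roughly what the RAII guard does around the enclosing scope.
#include <cuda_runtime.h>

void guarded_region(int target_device) {
  int previous_device = 0;
  cudaGetDevice(&previous_device);   // remember the caller's current device
  cudaSetDevice(target_device);      // make the data's device current for
                                     // stream lookup, allocations and launches
  // ... getCurrentCUDAStream(), cub calls, kernel launches ...
  cudaSetDevice(previous_device);    // done automatically by ~CUDAGuard
}
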
@@ -790,6 +798,7 @@ Result cuaev_forward(const Tensor& coordinates_t, const Tensor& species_t, const
         coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
         d_Rij,
         max_natoms_per_mol);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
   } else {
     dim3 block(8, 8, 1);
     // Compute pairwise distance (Rij) for all atom pairs in a molecule
@@ -800,6 +809,7 @@ Result cuaev_forward(const Tensor& coordinates_t, const Tensor& species_t, const
         coordinates_t.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
         d_Rij,
         max_natoms_per_mol);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
   }

   // Extract Rijs that is needed for RadialAEV comptuation i.e. all the Rij <= Rcr
@@ -822,6 +832,7 @@ Result cuaev_forward(const Tensor& coordinates_t, const Tensor& species_t, const
       aev_params.radial_length,
       aev_params.radial_sublength,
       nRadialRij);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();

   // reuse buffer allocated for all Rij
   // d_angularRij will store all the Rij required in Angular AEV computation
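
Each C10_CUDA_KERNEL_LAUNCH_CHECK() added in this file catches launch failures (bad grid/block configuration, missing kernel image for the device, and so on) at the launch site; kernel launches are asynchronous, so without the check such errors would only surface at some later synchronizing call. A hand-rolled equivalent of what the macro does (the real one, from c10/cuda/CUDAException.h, also reports file and line):

// Hand-rolled equivalent of the launch checks added throughout this file.
#include <cuda_runtime.h>
#include <c10/util/Exception.h>

inline void check_last_kernel_launch() {
  // returns (and clears) the error from the most recent CUDA runtime call
  cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess, "CUDA kernel launch failed: ", cudaGetErrorString(err));
}
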
@@ -890,6 +901,7 @@ Result cuaev_forward(const Tensor& coordinates_t, const Tensor& species_t, const
       maxnbrs_per_atom_aligned,
       angular_length_aligned,
       ncenter_atoms);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();

   return {
       aev_t,
@@ -917,7 +929,8 @@ Tensor cuaev_backward(const Tensor& grad_output, const AEVScalarParams& aev_para
   const int n_molecules = coordinates_t.size(0);
   const int max_natoms_per_mol = coordinates_t.size(1);

-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  at::cuda::CUDAGuard device_guard(coordinates_t.device().index());
+  at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();

   auto grad_coord = torch::zeros(coordinates_t.sizes(), coordinates_t.options().requires_grad(false)); // [2, 5, 3]
@@ -943,6 +956,7 @@ Tensor cuaev_backward(const Tensor& grad_output, const AEVScalarParams& aev_para
       aev_params.radial_length,
       aev_params.radial_sublength,
       result.nRadialRij);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();

   // For best result, block_size should match average molecule size (no padding) to avoid atomicAdd
   nblocks = (result.nRadialRij + block_size - 1) / block_size;
@@ -952,6 +966,7 @@ Tensor cuaev_backward(const Tensor& grad_output, const AEVScalarParams& aev_para
       grad_coord.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
       d_radialRij,
       result.nRadialRij);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();

   auto smem_size = [&aev_params](int max_nbrs, int ncatom_per_tpb) {
     int sxyz = sizeof(float) * max_nbrs * 3;
@@ -991,6 +1006,7 @@ Tensor cuaev_backward(const Tensor& grad_output, const AEVScalarParams& aev_para
       result.maxnbrs_per_atom_aligned,
       result.angular_length_aligned,
       result.ncenter_atoms);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();

   return grad_coord;
 }
@@ -1002,7 +1018,8 @@ Tensor cuaev_double_backward(const Tensor& grad_force, const AEVScalarParams& ae
   const int n_molecules = coordinates_t.size(0);
   const int max_natoms_per_mol = coordinates_t.size(1);

-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  at::cuda::CUDAGuard device_guard(coordinates_t.device().index());
+  at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();

   int aev_length = aev_params.radial_length + aev_params.angular_length;
@@ -1027,6 +1044,7 @@ Tensor cuaev_double_backward(const Tensor& grad_force, const AEVScalarParams& ae
       grad_force.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
       d_radialRij,
       result.nRadialRij);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();

   nblocks = (result.nRadialRij * 8 + block_size - 1) / block_size;
   cuRadialAEVs_backward_or_doublebackward<true, int, float, 8><<<nblocks, block_size, 0, stream>>>(
@@ -1040,6 +1058,7 @@ Tensor cuaev_double_backward(const Tensor& grad_force, const AEVScalarParams& ae
       aev_params.radial_length,
       aev_params.radial_sublength,
       result.nRadialRij);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();

   auto smem_size = [&aev_params](int max_nbrs, int ncatom_per_tpb) {
     int sxyz = sizeof(float) * max_nbrs * 3;
@@ -1078,6 +1097,7 @@ Tensor cuaev_double_backward(const Tensor& grad_force, const AEVScalarParams& ae
       result.maxnbrs_per_atom_aligned,
       result.angular_length_aligned,
       result.ncenter_atoms);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();

   return grad_grad_aev;
 }
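
Caller-side view of what the new testMultiGPU covers, sketched with libtorch (a hypothetical snippet, not part of the commit): inputs resident on the second GPU while the process' current device is still cuda:0.

// Hypothetical caller-side sketch; shapes mirror the [2, 5] system used in the tests.
#include <torch/torch.h>

int main() {
  if (torch::cuda::device_count() < 2)
    return 0; // same precondition as skipIfNoMultiGPU
  auto opts = torch::TensorOptions().device(torch::kCUDA, 1);
  auto coordinates = torch::rand({2, 5, 3}, opts.dtype(torch::kFloat32));
  auto species = torch::zeros({2, 5}, opts.dtype(torch::kInt32));
  // With the CUDAGuard in place, cuaev_forward and its backward passes now
  // launch their kernels on cuda:1, where these tensors actually live,
  // instead of into cuda:0's stream.
  return 0;
}
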