Unverified Commit 258e6c36 authored by Jinze (Richard) Xue's avatar Jinze (Richard) Xue Committed by GitHub
Browse files

cuaev def_pickle (#602)

parent 11f44927
...@@ -34,6 +34,18 @@ class TestCUAEVNoGPU(TestCase): ...@@ -34,6 +34,18 @@ class TestCUAEVNoGPU(TestCase):
coordinates = make_tensor((8, 0, 3), 'cpu', torch.float32, low=-5, high=5) coordinates = make_tensor((8, 0, 3), 'cpu', torch.float32, low=-5, high=5)
self.assertIn("cuaev::run", str(s.graph_for((species, coordinates)))) self.assertIn("cuaev::run", str(s.graph_for((species, coordinates))))
def testPickle(self):
path = os.path.dirname(os.path.realpath(__file__))
const_file = os.path.join(path, '../torchani/resources/ani-1x_8x/rHCNO-5.2R_16-3.5A_a4-8.params') # noqa: E501
consts = torchani.neurochem.Constants(const_file)
aev_computer = torchani.AEVComputer(**consts, use_cuda_extension=True)
tmpfile = '/tmp/cuaev.pkl'
with open(tmpfile, 'wb') as file:
pickle.dump(aev_computer, file)
with open(tmpfile, 'rb') as file:
aev_computer = pickle.load(file)
os.remove(tmpfile)
@skipIfNoGPU @skipIfNoGPU
@skipIfNoCUAEV @skipIfNoCUAEV
...@@ -140,6 +152,21 @@ class TestCUAEV(TestCase): ...@@ -140,6 +152,21 @@ class TestCUAEV(TestCase):
self.testSimpleDoubleBackward_2() self.testSimpleDoubleBackward_2()
self.setUp(device='cuda:0') self.setUp(device='cuda:0')
def testPickleCorrectness(self):
ref_aev_computer = self.cuaev_computer
tmpfile = '/tmp/cuaev.pkl'
with open(tmpfile, 'wb') as file:
pickle.dump(ref_aev_computer, file)
with open(tmpfile, 'rb') as file:
test_aev_computer = pickle.load(file)
os.remove(tmpfile)
coordinates = torch.rand([2, 50, 3], device=self.device) * 5
species = torch.randint(-1, 3, (2, 50), device=self.device)
_, ref_aev = ref_aev_computer((species, coordinates))
_, test_aev = test_aev_computer((species, coordinates))
self.assertEqual(ref_aev, test_aev)
def testSimpleBackward(self): def testSimpleBackward(self):
coordinates = torch.tensor([ coordinates = torch.tensor([
[[0.03192167, 0.00638559, 0.01301679], [[0.03192167, 0.00638559, 0.01301679],
......
...@@ -150,7 +150,35 @@ Tensor run_autograd( ...@@ -150,7 +150,35 @@ Tensor run_autograd(
TORCH_LIBRARY(cuaev, m) { TORCH_LIBRARY(cuaev, m) {
m.class_<CuaevComputer>("CuaevComputer") m.class_<CuaevComputer>("CuaevComputer")
.def(torch::init<double, double, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, int64_t>()); .def(torch::init<double, double, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, int64_t>())
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<CuaevComputer>& self) -> std::vector<Tensor> {
std::vector<Tensor> state;
state.push_back(torch::tensor(self->aev_params.Rcr));
state.push_back(torch::tensor(self->aev_params.Rca));
state.push_back(self->aev_params.EtaR_t);
state.push_back(self->aev_params.ShfR_t);
state.push_back(self->aev_params.EtaA_t);
state.push_back(self->aev_params.Zeta_t);
state.push_back(self->aev_params.ShfA_t);
state.push_back(self->aev_params.ShfZ_t);
state.push_back(torch::tensor(self->aev_params.num_species));
return state;
},
// __setstate__
[](std::vector<Tensor> state) -> c10::intrusive_ptr<CuaevComputer> {
return c10::make_intrusive<CuaevComputer>(
state[0].item<double>(),
state[1].item<double>(),
state[2],
state[3],
state[4],
state[5],
state[6],
state[7],
state[8].item<int64_t>());
});
m.def("run", run_only_forward); m.def("run", run_only_forward);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment