Unverified Commit 53fc6392 authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

[Torch] Fix inter-process communication for PyTorch (#1858)

* fix

* add comment
parent 62077ef3
...@@ -25,7 +25,12 @@ class HeteroGraphIndex(ObjectBase): ...@@ -25,7 +25,12 @@ class HeteroGraphIndex(ObjectBase):
return obj return obj
def __getstate__(self): def __getstate__(self):
return _CAPI_DGLHeteroPickle(self) """Issue: https://github.com/pytorch/pytorch/issues/32351
Need to set the tensor created in the __getstate__ function
as object attribute to avoid potential bugs
"""
self._pk_state = _CAPI_DGLHeteroPickle(self)
return self._pk_state
def __setstate__(self, state): def __setstate__(self, state):
self._cache = {} self._cache = {}
...@@ -1230,8 +1235,12 @@ class HeteroPickleStates(ObjectBase): ...@@ -1230,8 +1235,12 @@ class HeteroPickleStates(ObjectBase):
return [arr_func(i) for i in range(num_arr)] return [arr_func(i) for i in range(num_arr)]
def __getstate__(self): def __getstate__(self):
arrays = [F.zerocopy_from_dgl_ndarray(arr) for arr in self.arrays] """Issue: https://github.com/pytorch/pytorch/issues/32351
return self.version, self.meta, arrays Need to set the tensor created in the __getstate__ function
as object attribute to avoid potential bugs
"""
self._pk_arrays = [F.zerocopy_from_dgl_ndarray(arr) for arr in self.arrays]
return self.version, self.meta, self._pk_arrays
def __setstate__(self, state): def __setstate__(self, state):
if isinstance(state[0], int): if isinstance(state[0], int):
...@@ -1244,4 +1253,5 @@ class HeteroPickleStates(ObjectBase): ...@@ -1244,4 +1253,5 @@ class HeteroPickleStates(ObjectBase):
num_nodes_per_type = F.zerocopy_to_dgl_ndarray(num_nodes_per_type) num_nodes_per_type = F.zerocopy_to_dgl_ndarray(num_nodes_per_type)
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_CAPI_DGLCreateHeteroPickleStatesOld, metagraph, num_nodes_per_type, adjs) _CAPI_DGLCreateHeteroPickleStatesOld, metagraph, num_nodes_per_type, adjs)
_init_api("dgl.heterograph_index") _init_api("dgl.heterograph_index")
import dgl
import torch as th
import torch.multiprocessing as mp
import os
import unittest
def sub_ipc(g):
print(g)
return g
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_torch_ipc():
g = dgl.graph([(0, 1), (1, 2), (2, 3)])
ctx = mp.get_context("spawn")
p = ctx.Process(target=sub_ipc, args=(g, ))
p.start()
p.join()
if __name__ == "__main__":
test_torch_ipc()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment