Unverified Commit adb3a7c1 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[BugFix] record/restore pin status when pickle/unpickle (#3914)

* [BugFix] record/restore pin status when pickle/unpickle

* disable test on TF

* set version as expected

* unpin memory in test
parent ed15b471
...@@ -111,6 +111,11 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -111,6 +111,11 @@ class BaseHeteroGraph : public runtime::Object {
*/ */
virtual DLContext Context() const = 0; virtual DLContext Context() const = 0;
/*!
* \brief Pin the memory of this graph.
*/
virtual void PinMemory_() = 0;
/*! /*!
* \brief Check if this graph is pinned. * \brief Check if this graph is pinned.
*/ */
......
...@@ -1356,10 +1356,10 @@ class HeteroPickleStates(ObjectBase): ...@@ -1356,10 +1356,10 @@ class HeteroPickleStates(ObjectBase):
def __setstate__(self, state): def __setstate__(self, state):
if isinstance(state[0], int): if isinstance(state[0], int):
_, meta, arrays = state version, meta, arrays = state
arrays = [F.zerocopy_to_dgl_ndarray(arr) for arr in arrays] arrays = [F.zerocopy_to_dgl_ndarray(arr) for arr in arrays]
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_CAPI_DGLCreateHeteroPickleStates, meta, arrays) _CAPI_DGLCreateHeteroPickleStates, version, meta, arrays)
else: else:
metagraph, num_nodes_per_type, adjs = state metagraph, num_nodes_per_type, adjs = state
num_nodes_per_type = F.zerocopy_to_dgl_ndarray(num_nodes_per_type) num_nodes_per_type = F.zerocopy_to_dgl_ndarray(num_nodes_per_type)
...@@ -1367,9 +1367,9 @@ class HeteroPickleStates(ObjectBase): ...@@ -1367,9 +1367,9 @@ class HeteroPickleStates(ObjectBase):
_CAPI_DGLCreateHeteroPickleStatesOld, metagraph, num_nodes_per_type, adjs) _CAPI_DGLCreateHeteroPickleStatesOld, metagraph, num_nodes_per_type, adjs)
def _forking_rebuild(pk_state): def _forking_rebuild(pk_state):
meta, arrays = pk_state version, meta, arrays = pk_state
arrays = [F.to_dgl_nd(arr) for arr in arrays] arrays = [F.to_dgl_nd(arr) for arr in arrays]
states = _CAPI_DGLCreateHeteroPickleStates(meta, arrays) states = _CAPI_DGLCreateHeteroPickleStates(version, meta, arrays)
graph_index = _CAPI_DGLHeteroForkingUnpickle(states) graph_index = _CAPI_DGLHeteroForkingUnpickle(states)
graph_index._forking_pk_state = pk_state graph_index._forking_pk_state = pk_state
return graph_index return graph_index
...@@ -1391,7 +1391,7 @@ def _forking_reduce(graph_index): ...@@ -1391,7 +1391,7 @@ def _forking_reduce(graph_index):
# the tensors as an attribute of the original graph index object. Otherwise # the tensors as an attribute of the original graph index object. Otherwise
# PyTorch will throw weird errors like bad value(s) in fds_to_keep or unable to # PyTorch will throw weird errors like bad value(s) in fds_to_keep or unable to
# resize file. # resize file.
graph_index._forking_pk_state = (states.meta, arrays) graph_index._forking_pk_state = (states.version, states.meta, arrays)
return _forking_rebuild, (graph_index._forking_pk_state,) return _forking_rebuild, (graph_index._forking_pk_state,)
......
...@@ -240,7 +240,7 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -240,7 +240,7 @@ class HeteroGraph : public BaseHeteroGraph {
* kDLGPU: invalid, will throw an error. * kDLGPU: invalid, will throw an error.
* The context check is deferred to pinning the NDArray. * The context check is deferred to pinning the NDArray.
*/ */
void PinMemory_(); void PinMemory_() override;
/*! /*!
* \brief Unpin all relation graphs of the current graph. * \brief Unpin all relation graphs of the current graph.
......
...@@ -18,10 +18,12 @@ namespace dgl { ...@@ -18,10 +18,12 @@ namespace dgl {
HeteroPickleStates HeteroPickle(HeteroGraphPtr graph) { HeteroPickleStates HeteroPickle(HeteroGraphPtr graph) {
HeteroPickleStates states; HeteroPickleStates states;
states.version = 2;
dmlc::MemoryStringStream ofs(&states.meta); dmlc::MemoryStringStream ofs(&states.meta);
dmlc::Stream *strm = &ofs; dmlc::Stream *strm = &ofs;
strm->Write(ImmutableGraph::ToImmutable(graph->meta_graph())); strm->Write(ImmutableGraph::ToImmutable(graph->meta_graph()));
strm->Write(graph->NumVerticesPerType()); strm->Write(graph->NumVerticesPerType());
strm->Write(graph->IsPinned());
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) { for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
SparseFormat fmt = graph->SelectFormat(etype, ALL_CODE); SparseFormat fmt = graph->SelectFormat(etype, ALL_CODE);
switch (fmt) { switch (fmt) {
...@@ -53,10 +55,12 @@ HeteroPickleStates HeteroPickle(HeteroGraphPtr graph) { ...@@ -53,10 +55,12 @@ HeteroPickleStates HeteroPickle(HeteroGraphPtr graph) {
HeteroPickleStates HeteroForkingPickle(HeteroGraphPtr graph) { HeteroPickleStates HeteroForkingPickle(HeteroGraphPtr graph) {
HeteroPickleStates states; HeteroPickleStates states;
states.version = 2;
dmlc::MemoryStringStream ofs(&states.meta); dmlc::MemoryStringStream ofs(&states.meta);
dmlc::Stream *strm = &ofs; dmlc::Stream *strm = &ofs;
strm->Write(ImmutableGraph::ToImmutable(graph->meta_graph())); strm->Write(ImmutableGraph::ToImmutable(graph->meta_graph()));
strm->Write(graph->NumVerticesPerType()); strm->Write(graph->NumVerticesPerType());
strm->Write(graph->IsPinned());
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) { for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
auto created_formats = graph->GetCreatedFormats(); auto created_formats = graph->GetCreatedFormats();
auto allowed_formats = graph->GetAllowedFormats(); auto allowed_formats = graph->GetAllowedFormats();
...@@ -97,6 +101,10 @@ HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) { ...@@ -97,6 +101,10 @@ HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) {
std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges()); std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());
std::vector<int64_t> num_nodes_per_type; std::vector<int64_t> num_nodes_per_type;
CHECK(strm->Read(&num_nodes_per_type)) << "Invalid num_nodes_per_type"; CHECK(strm->Read(&num_nodes_per_type)) << "Invalid num_nodes_per_type";
bool is_pinned = false;
if (states.version > 1) {
CHECK(strm->Read(&is_pinned)) << "Invalid flag 'is_pinned'";
}
auto array_itr = states.arrays.begin(); auto array_itr = states.arrays.begin();
for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) { for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {
...@@ -141,7 +149,11 @@ HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) { ...@@ -141,7 +149,11 @@ HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) {
} }
relgraphs[etype] = relgraph; relgraphs[etype] = relgraph;
} }
return CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type); auto graph = CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type);
if (is_pinned) {
graph->PinMemory_();
}
return graph;
} }
// For backward compatibility // For backward compatibility
...@@ -183,6 +195,10 @@ HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) { ...@@ -183,6 +195,10 @@ HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) {
std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges()); std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());
std::vector<int64_t> num_nodes_per_type; std::vector<int64_t> num_nodes_per_type;
CHECK(strm->Read(&num_nodes_per_type)) << "Invalid num_nodes_per_type"; CHECK(strm->Read(&num_nodes_per_type)) << "Invalid num_nodes_per_type";
bool is_pinned = false;
if (states.version > 1) {
CHECK(strm->Read(&is_pinned)) << "Invalid flag 'is_pinned'";
}
auto array_itr = states.arrays.begin(); auto array_itr = states.arrays.begin();
for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) { for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {
...@@ -234,7 +250,11 @@ HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) { ...@@ -234,7 +250,11 @@ HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) {
relgraphs[etype] = UnitGraph::CreateUnitGraphFrom( relgraphs[etype] = UnitGraph::CreateUnitGraphFrom(
num_vtypes, csc, csr, coo, has_csc, has_csr, has_coo, allowed_formats); num_vtypes, csc, csr, coo, has_csc, has_csr, has_coo, allowed_formats);
} }
return CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type); auto graph = CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type);
if (is_pinned) {
graph->PinMemory_();
}
return graph;
} }
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetVersion") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetVersion")
...@@ -266,10 +286,11 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetArraysNum") ...@@ -266,10 +286,11 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetArraysNum")
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLCreateHeteroPickleStates") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLCreateHeteroPickleStates")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string meta = args[0]; const int version = args[0];
const List<Value> arrays = args[1]; std::string meta = args[1];
const List<Value> arrays = args[2];
std::shared_ptr<HeteroPickleStates> st( new HeteroPickleStates ); std::shared_ptr<HeteroPickleStates> st( new HeteroPickleStates );
st->version = 1; st->version = version == 0 ? 1 : version;
st->meta = meta; st->meta = meta;
st->arrays.reserve(arrays.size()); st->arrays.reserve(arrays.size());
for (const auto& ref : arrays) { for (const auto& ref : arrays) {
...@@ -303,10 +324,11 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpickle") ...@@ -303,10 +324,11 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpickle")
graph = HeteroUnpickleOld(*ref.sptr()); graph = HeteroUnpickleOld(*ref.sptr());
break; break;
case 1: case 1:
case 2:
graph = HeteroUnpickle(*ref.sptr()); graph = HeteroUnpickle(*ref.sptr());
break; break;
default: default:
LOG(FATAL) << "Version can only be 0 or 1."; LOG(FATAL) << "Version can only be 0 or 1 or 2.";
} }
*rv = HeteroGraphRef(graph); *rv = HeteroGraphRef(graph);
}); });
......
...@@ -218,7 +218,7 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -218,7 +218,7 @@ class UnitGraph : public BaseHeteroGraph {
* kDLGPU: invalid, will throw an error. * kDLGPU: invalid, will throw an error.
* The context check is deferred to pinning the NDArray. * The context check is deferred to pinning the NDArray.
*/ */
void PinMemory_(); void PinMemory_() override;
/*! /*!
* \brief Unpin the in_csr_, out_csr_ and coo_ of the current graph. * \brief Unpin the in_csr_, out_csr_ and coo_ of the current graph.
......
...@@ -163,6 +163,31 @@ def test_pickling_subgraph(): ...@@ -163,6 +163,31 @@ def test_pickling_subgraph():
f1.close() f1.close()
f2.close() f2.close()
@unittest.skipIf(F._default_context_str != 'gpu', reason="Need GPU for pin")
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TensorFlow create graph on gpu when unpickle")
@parametrize_dtype
def test_pickling_is_pinned(idtype):
    """Check that a pinned graph is still pinned after pickling or deep-copying.

    A random homogeneous graph and a small heterograph are both created on
    CPU, pinned, and then passed through ``_reconstruct_pickle`` and
    ``copy.deepcopy``; each resulting graph must report ``is_pinned()``.

    Parameters
    ----------
    idtype
        Index dtype supplied by ``parametrize_dtype``.
    """
    import copy

    homo_graph = dgl.rand_graph(10, 20, idtype=idtype, device=F.cpu())
    relations = {
        ('user', 'follows', 'user'): ([0, 1], [1, 2]),
        ('user', 'plays', 'game'): ([0, 1, 2, 1], [0, 0, 1, 1]),
        ('user', 'wishes', 'game'): ([0, 2], [1, 0]),
        ('developer', 'develops', 'game'): ([0, 1], [0, 1]),
    }
    hetero_graph = dgl.heterograph(relations, idtype=idtype, device=F.cpu())

    for graph in (homo_graph, hetero_graph):
        # Freshly built CPU graphs start out unpinned.
        assert not graph.is_pinned()
        graph.pin_memory_()
        assert graph.is_pinned()

        # The pickled-and-restored graph must carry the pinned status over.
        unpickled = _reconstruct_pickle(graph)
        assert unpickled.is_pinned()
        unpickled.unpin_memory_()

        # The deep copy must carry the pinned status over as well.
        duplicate = copy.deepcopy(graph)
        assert duplicate.is_pinned()
        duplicate.unpin_memory_()

        # Unpin the source graph last so nothing stays pinned after the test.
        graph.unpin_memory_()
if __name__ == '__main__': if __name__ == '__main__':
test_pickling_index() test_pickling_index()
test_pickling_graph_index() test_pickling_graph_index()
...@@ -172,3 +197,4 @@ if __name__ == '__main__': ...@@ -172,3 +197,4 @@ if __name__ == '__main__':
test_pickling_batched_graph() test_pickling_batched_graph()
test_pickling_heterograph() test_pickling_heterograph()
test_pickling_batched_heterograph() test_pickling_batched_heterograph()
test_pickling_is_pinned()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment