Unverified Commit 20ec7bb0 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bug?] Fixing compatibility of pickles across DGL versions (#1507)



* pickle compatibility before 0.4.2

* fixing for utils.Index

* add TODOs

* fix

* fix

* more compatibility checks for DGLGraph and DGLHeteroGraph objects

* lint
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent c2e61ce1
...@@ -1050,6 +1050,15 @@ class DGLGraph(DGLBaseGraph): ...@@ -1050,6 +1050,15 @@ class DGLGraph(DGLBaseGraph):
# set parent if the graph is a subgraph. # set parent if the graph is a subgraph.
self._parent = parent self._parent = parent
def __setstate__(self, state):
# Compatibility with pickles from DGL 0.4.2-
if '_batch_num_nodes' not in state:
state = state.copy()
state.setdefault('_batch_num_nodes', None)
state.setdefault('_batch_num_edges', None)
state.setdefault('_parent', None)
self.__dict__.update(state)
def _create_subgraph(self, sgi, induced_nodes, induced_edges): def _create_subgraph(self, sgi, induced_nodes, induced_edges):
"""Internal function to create a subgraph from index.""" """Internal function to create a subgraph from index."""
subg = DGLGraph(graph_data=sgi.graph, subg = DGLGraph(graph_data=sgi.graph,
......
...@@ -50,7 +50,16 @@ class GraphIndex(ObjectBase): ...@@ -50,7 +50,16 @@ class GraphIndex(ObjectBase):
"""The pickle state of GraphIndex is defined as a triplet """The pickle state of GraphIndex is defined as a triplet
(number_of_nodes, readonly, src_nodes, dst_nodes) (number_of_nodes, readonly, src_nodes, dst_nodes)
""" """
# Pickle compatibility check
# TODO: we should store a storage version number in later releases.
if isinstance(state, tuple) and len(state) == 5:
dgl_warning("The object is pickled pre-0.4.2. Multigraph flag is ignored in 0.4.3")
num_nodes, _, readonly, src, dst = state
elif isinstance(state, tuple) and len(state) == 4:
# post-0.4.3.
num_nodes, readonly, src, dst = state num_nodes, readonly, src, dst = state
else:
raise IOError('Unrecognized storage format.')
self._cache = {} self._cache = {}
self._readonly = readonly self._readonly = readonly
......
...@@ -275,7 +275,19 @@ class DGLHeteroGraph(object): ...@@ -275,7 +275,19 @@ class DGLHeteroGraph(object):
return self._graph, self._ntypes, self._etypes, self._node_frames, self._edge_frames return self._graph, self._ntypes, self._etypes, self._node_frames, self._edge_frames
def __setstate__(self, state): def __setstate__(self, state):
# Compatibility check
# TODO: version the storage
if isinstance(state, tuple) and len(state) == 5:
# DGL 0.4.3+
self._init(*state) self._init(*state)
elif isinstance(state, dict):
# DGL 0.4.2-
dgl_warning("The object is pickled with DGL version 0.4.2-. "
"Some of the original attributes are ignored.")
self._init(state['_graph'], state['_ntypes'], state['_etypes'], state['_node_frames'],
state['_edge_frames'])
else:
raise IOError("Unrecognized pickle format.")
def _get_msg_index(self, etid): def _get_msg_index(self, etid):
"""Internal function for getting the message index array of the given edge type id.""" """Internal function for getting the message index array of the given edge type id."""
......
...@@ -29,7 +29,28 @@ class HeteroGraphIndex(ObjectBase): ...@@ -29,7 +29,28 @@ class HeteroGraphIndex(ObjectBase):
def __setstate__(self, state): def __setstate__(self, state):
self._cache = {} self._cache = {}
# Pickle compatibility check
# TODO: we should store a storage version number in later releases.
if isinstance(state, HeteroPickleStates):
# post-0.4.3
self.__init_handle_by_constructor__(_CAPI_DGLHeteroUnpickle, state) self.__init_handle_by_constructor__(_CAPI_DGLHeteroUnpickle, state)
elif isinstance(state, tuple) and len(state) == 3:
# pre-0.4.2
metagraph, number_of_nodes, edges = state
self._cache = {}
# loop over etypes and recover unit graphs
rel_graphs = []
for i, edges_per_type in enumerate(edges):
src_ntype, dst_ntype = metagraph.find_edge(i)
num_src = number_of_nodes[src_ntype]
num_dst = number_of_nodes[dst_ntype]
src_id, dst_id, _ = edges_per_type
rel_graphs.append(create_unitgraph_from_coo(
1 if src_ntype == dst_ntype else 2, num_src, num_dst, src_id, dst_id, "any"))
self.__init_handle_by_constructor__(
_CAPI_DGLHeteroCreateHeteroGraph, metagraph, rel_graphs)
@property @property
def metagraph(self): def metagraph(self):
......
...@@ -5,7 +5,7 @@ from collections.abc import Mapping, Iterable ...@@ -5,7 +5,7 @@ from collections.abc import Mapping, Iterable
from functools import wraps from functools import wraps
import numpy as np import numpy as np
from .base import DGLError from .base import DGLError, dgl_warning
from . import backend as F from . import backend as F
from . import ndarray as nd from . import ndarray as nd
...@@ -147,8 +147,17 @@ class Index(object): ...@@ -147,8 +147,17 @@ class Index(object):
return self.tousertensor(), self.dtype return self.tousertensor(), self.dtype
def __setstate__(self, state): def __setstate__(self, state):
# Pickle compatibility check
# TODO: we should store a storage version number in later releases.
if isinstance(state, tuple) and len(state) == 2:
# post-0.4.4
data, self.dtype = state data, self.dtype = state
self._initialize_data(data) self._initialize_data(data)
else:
# pre-0.4.3
dgl_warning("The object is pickled before 0.4.3. Setting dtype of graph to int64")
self.dtype = 'int64'
self._initialize_data(state)
def get_items(self, index): def get_items(self, index):
"""Return values at given positions of an Index """Return values at given positions of an Index
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment