"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "e4b056fe652536ac89ff2c98e36b2d3685cbccd2"
Unverified Commit 53629082 authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

Fix save/load for old version files (#1994)

* load hetero

* fix

* more error message

* add warning

* add alias to dgl namespace

* fix tests
parent 69f5869f
...@@ -32,6 +32,7 @@ from .traversal import * ...@@ -32,6 +32,7 @@ from .traversal import *
from .transform import * from .transform import *
from .propagate import * from .propagate import *
from .random import * from .random import *
from .data.utils import save_graphs, load_graphs
from ._deprecate.graph import DGLGraph as DGLGraphStale from ._deprecate.graph import DGLGraph as DGLGraphStale
from ._deprecate.nodeflow import * from ._deprecate.nodeflow import *
"""For Graph Serialization""" """For Graph Serialization"""
from __future__ import absolute_import from __future__ import absolute_import
import os import os
from .._deprecate.graph import DGLGraph from ..base import dgl_warning
from ..heterograph import DGLHeteroGraph from ..heterograph import DGLHeteroGraph
from .._ffi.object import ObjectBase, register_object from .._ffi.object import ObjectBase, register_object
from .._ffi.function import _init_api from .._ffi.function import _init_api
from .. import backend as F from .. import backend as F
from .heterograph_serialize import HeteroGraphData, save_heterographs from .heterograph_serialize import save_heterographs
_init_api("dgl.data.graph_serialize") _init_api("dgl.data.graph_serialize")
...@@ -30,7 +30,7 @@ class GraphData(ObjectBase): ...@@ -30,7 +30,7 @@ class GraphData(ObjectBase):
"""GraphData Object""" """GraphData Object"""
@staticmethod @staticmethod
def create(g: DGLGraph): def create(g):
"""Create GraphData""" """Create GraphData"""
# TODO(zihao): support serialize batched graph in the future. # TODO(zihao): support serialize batched graph in the future.
assert g.batch_size == 1, "Batched DGLGraph is not supported for serialization" assert g.batch_size == 1, "Batched DGLGraph is not supported for serialization"
...@@ -52,9 +52,10 @@ class GraphData(ObjectBase): ...@@ -52,9 +52,10 @@ class GraphData(ObjectBase):
return _CAPI_MakeGraphData(ghandle, node_tensors, edge_tensors) return _CAPI_MakeGraphData(ghandle, node_tensors, edge_tensors)
def get_graph(self): def get_graph(self):
"""Get DGLGraph from GraphData""" """Get DGLHeteroGraph from GraphData"""
ghandle = _CAPI_GDataGraphHandle(self) ghandle = _CAPI_GDataGraphHandle(self)
g = DGLGraph(graph_data=ghandle, readonly=True) hgi =_CAPI_DGLAsHeteroGraph(ghandle)
g = DGLHeteroGraph(hgi, ['_U'], ['_E'])
node_tensors_items = _CAPI_GDataNodeTensors(self).items() node_tensors_items = _CAPI_GDataNodeTensors(self).items()
edge_tensors_items = _CAPI_GDataEdgeTensors(self).items() edge_tensors_items = _CAPI_GDataEdgeTensors(self).items()
for k, v in node_tensors_items: for k, v in node_tensors_items:
...@@ -66,7 +67,7 @@ class GraphData(ObjectBase): ...@@ -66,7 +67,7 @@ class GraphData(ObjectBase):
def save_graphs(filename, g_list, labels=None): def save_graphs(filename, g_list, labels=None):
r""" r"""
Save DGLGraphs/DGLHeteroGraph and graph labels to file Save DGLGraphs and graph labels to file
Parameters Parameters
---------- ----------
...@@ -104,28 +105,13 @@ def save_graphs(filename, g_list, labels=None): ...@@ -104,28 +105,13 @@ def save_graphs(filename, g_list, labels=None):
os.makedirs(f_path) os.makedirs(f_path)
g_sample = g_list[0] if isinstance(g_list, list) else g_list g_sample = g_list[0] if isinstance(g_list, list) else g_list
if isinstance(g_sample, DGLGraph): if type(g_sample) == DGLHeteroGraph: # Doesn't support DGLHeteroGraph's derived class
save_dglgraphs(filename, g_list, labels)
elif isinstance(g_sample, DGLHeteroGraph):
save_heterographs(filename, g_list, labels) save_heterographs(filename, g_list, labels)
else: else:
raise Exception( raise Exception(
"Invalid argument g_list. Must be a DGLGraph or a list of DGLGraphs/DGLHeteroGraphs") "Invalid argument g_list. Must be a DGLGraph or a list of DGLGraphs/DGLHeteroGraphs")
def save_dglgraphs(filename, g_list, labels=None):
"""Internal function to save DGLGraphs"""
if isinstance(g_list, DGLGraph):
g_list = [g_list]
if (labels is not None) and (len(labels) != 0):
label_dict = dict()
for key, value in labels.items():
label_dict[key] = F.zerocopy_to_dgl_ndarray(value)
else:
label_dict = None
gdata_list = [GraphData.create(g) for g in g_list]
_CAPI_SaveDGLGraphs_V0(filename, gdata_list, label_dict)
def load_graphs(filename, idx_list=None): def load_graphs(filename, idx_list=None):
""" """
...@@ -162,6 +148,9 @@ def load_graphs(filename, idx_list=None): ...@@ -162,6 +148,9 @@ def load_graphs(filename, idx_list=None):
version = _CAPI_GetFileVersion(filename) version = _CAPI_GetFileVersion(filename)
if version == 1: if version == 1:
dgl_warning(
"You are loading a graph file saved by old version of dgl. \
Please consider saving it again with the current format.")
return load_graph_v1(filename, idx_list) return load_graph_v1(filename, idx_list)
elif version == 2: elif version == 2:
return load_graph_v2(filename, idx_list) return load_graph_v2(filename, idx_list)
...@@ -191,7 +180,6 @@ def load_graph_v1(filename, idx_list=None): ...@@ -191,7 +180,6 @@ def load_graph_v1(filename, idx_list=None):
return [gdata.get_graph() for gdata in metadata.graph_data], label_dict return [gdata.get_graph() for gdata in metadata.graph_data], label_dict
def load_labels(filename): def load_labels(filename):
""" """
Load label dict from file Load label dict from file
......
...@@ -24,6 +24,7 @@ def save_heterographs(filename, g_list, labels): ...@@ -24,6 +24,7 @@ def save_heterographs(filename, g_list, labels):
labels = {} labels = {}
if isinstance(g_list, DGLHeteroGraph): if isinstance(g_list, DGLHeteroGraph):
g_list = [g_list] g_list = [g_list]
assert all([type(g) == DGLHeteroGraph for g in g_list]), "Invalid DGLHeteroGraph in g_list argument"
gdata_list = [HeteroGraphData.create(g) for g in g_list] gdata_list = [HeteroGraphData.create(g) for g in g_list]
_CAPI_SaveHeteroGraphData(filename, gdata_list, tensor_dict_to_ndarray_dict(labels)) _CAPI_SaveHeteroGraphData(filename, gdata_list, tensor_dict_to_ndarray_dict(labels))
......
...@@ -144,6 +144,14 @@ DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_LoadGraphFiles_V1") ...@@ -144,6 +144,14 @@ DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_LoadGraphFiles_V1")
*rv = LoadDGLGraphs(filename, idx_list, onlyMeta); *rv = LoadDGLGraphs(filename, idx_list, onlyMeta);
}); });
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_DGLAsHeteroGraph")
.set_body([](DGLArgs args, DGLRetValue *rv) {
GraphRef g = args[0];
ImmutableGraphPtr ig = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(ig) << "graph is not readonly";
*rv = HeteroGraphRef(ig->AsHeteroGraph());
});
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_LoadGraphFiles_V2") DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_LoadGraphFiles_V2")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
std::string filename = args[0]; std::string filename = args[0];
......
...@@ -10,7 +10,7 @@ import unittest ...@@ -10,7 +10,7 @@ import unittest
from dgl import DGLGraph from dgl import DGLGraph
import dgl import dgl
import dgl.ndarray as nd import dgl.ndarray as nd
from dgl.data.utils import save_graphs, load_graphs, load_labels, save_tensors, load_tensors from dgl.data.utils import load_labels, save_tensors, load_tensors
np.random.seed(44) np.random.seed(44)
...@@ -51,11 +51,11 @@ def test_graph_serialize_with_feature(is_hetero): ...@@ -51,11 +51,11 @@ def test_graph_serialize_with_feature(is_hetero):
path = f.name path = f.name
f.close() f.close()
save_graphs(path, g_list) dgl.save_graphs(path, g_list)
t2 = time.time() t2 = time.time()
idx_list = np.random.permutation(np.arange(num_graphs)).tolist() idx_list = np.random.permutation(np.arange(num_graphs)).tolist()
loadg_list, _ = load_graphs(path, idx_list) loadg_list, _ = dgl.load_graphs(path, idx_list)
t3 = time.time() t3 = time.time()
idx = idx_list[0] idx = idx_list[0]
...@@ -88,10 +88,10 @@ def test_graph_serialize_without_feature(is_hetero): ...@@ -88,10 +88,10 @@ def test_graph_serialize_without_feature(is_hetero):
path = f.name path = f.name
f.close() f.close()
save_graphs(path, g_list) dgl.save_graphs(path, g_list)
idx_list = np.random.permutation(np.arange(num_graphs)).tolist() idx_list = np.random.permutation(np.arange(num_graphs)).tolist()
loadg_list, _ = load_graphs(path, idx_list) loadg_list, _ = dgl.load_graphs(path, idx_list)
idx = idx_list[0] idx = idx_list[0]
load_g = loadg_list[0] load_g = loadg_list[0]
...@@ -117,10 +117,10 @@ def test_graph_serialize_with_labels(is_hetero): ...@@ -117,10 +117,10 @@ def test_graph_serialize_with_labels(is_hetero):
path = f.name path = f.name
f.close() f.close()
save_graphs(path, g_list, labels) dgl.save_graphs(path, g_list, labels)
idx_list = np.random.permutation(np.arange(num_graphs)).tolist() idx_list = np.random.permutation(np.arange(num_graphs)).tolist()
loadg_list, l_labels0 = load_graphs(path, idx_list) loadg_list, l_labels0 = dgl.load_graphs(path, idx_list)
l_labels = load_labels(path) l_labels = load_labels(path)
assert F.allclose(l_labels['label'], labels['label']) assert F.allclose(l_labels['label'], labels['label'])
assert F.allclose(l_labels0['label'], labels['label']) assert F.allclose(l_labels0['label'], labels['label'])
...@@ -185,7 +185,7 @@ def test_serialize_empty_dict(): ...@@ -185,7 +185,7 @@ def test_serialize_empty_dict():
def test_load_old_files1(): def test_load_old_files1():
loadg_list, _ = load_graphs(os.path.join( loadg_list, _ = dgl.load_graphs(os.path.join(
os.path.dirname(__file__), "data/1.bin")) os.path.dirname(__file__), "data/1.bin"))
idx, num_nodes, edge0, edge1, edata_e1, edata_e2, ndata_n1 = np.load( idx, num_nodes, edge0, edge1, edata_e1, edata_e2, ndata_n1 = np.load(
os.path.join(os.path.dirname(__file__), "data/1.npy"), allow_pickle=True) os.path.join(os.path.dirname(__file__), "data/1.npy"), allow_pickle=True)
...@@ -201,7 +201,7 @@ def test_load_old_files1(): ...@@ -201,7 +201,7 @@ def test_load_old_files1():
def test_load_old_files2(): def test_load_old_files2():
loadg_list, labels0 = load_graphs(os.path.join( loadg_list, labels0 = dgl.load_graphs(os.path.join(
os.path.dirname(__file__), "data/2.bin")) os.path.dirname(__file__), "data/2.bin"))
labels1 = load_labels(os.path.join( labels1 = load_labels(os.path.join(
os.path.dirname(__file__), "data/2.bin")) os.path.dirname(__file__), "data/2.bin"))
...@@ -211,6 +211,7 @@ def test_load_old_files2(): ...@@ -211,6 +211,7 @@ def test_load_old_files2():
assert np.allclose(F.asnumpy(labels1['label']), np_labels) assert np.allclose(F.asnumpy(labels1['label']), np_labels)
load_g = loadg_list[idx] load_g = loadg_list[idx]
print(load_g)
load_edges = load_g.all_edges('uv', 'eid') load_edges = load_g.all_edges('uv', 'eid')
assert np.allclose(F.asnumpy(load_edges[0]), edges0) assert np.allclose(F.asnumpy(load_edges[0]), edges0)
assert np.allclose(F.asnumpy(load_edges[1]), edges1) assert np.allclose(F.asnumpy(load_edges[1]), edges1)
...@@ -242,7 +243,7 @@ def create_heterographs2(idtype): ...@@ -242,7 +243,7 @@ def create_heterographs2(idtype):
def test_deserialize_old_heterograph_file(): def test_deserialize_old_heterograph_file():
path = os.path.join( path = os.path.join(
os.path.dirname(__file__), "data/hetero1.bin") os.path.dirname(__file__), "data/hetero1.bin")
g_list, label_dict = load_graphs(path) g_list, label_dict = dgl.load_graphs(path)
assert g_list[0].idtype == F.int64 assert g_list[0].idtype == F.int64
assert g_list[3].idtype == F.int32 assert g_list[3].idtype == F.int32
assert np.allclose( assert np.allclose(
...@@ -259,7 +260,7 @@ def create_old_heterograph_files(): ...@@ -259,7 +260,7 @@ def create_old_heterograph_files():
os.path.dirname(__file__), "data/hetero1.bin") os.path.dirname(__file__), "data/hetero1.bin")
g_list0 = create_heterographs(F.int64) + create_heterographs(F.int32) g_list0 = create_heterographs(F.int64) + create_heterographs(F.int32)
labels_dict = {"graph_label": F.ones(54)} labels_dict = {"graph_label": F.ones(54)}
save_graphs(path, g_list0, labels_dict) dgl.save_graphs(path, g_list0, labels_dict)
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU not implemented") @unittest.skipIf(F._default_context_str == 'gpu', reason="GPU not implemented")
...@@ -268,9 +269,9 @@ def test_serialize_heterograph(): ...@@ -268,9 +269,9 @@ def test_serialize_heterograph():
path = f.name path = f.name
f.close() f.close()
g_list0 = create_heterographs2(F.int64) + create_heterographs2(F.int32) g_list0 = create_heterographs2(F.int64) + create_heterographs2(F.int32)
save_graphs(path, g_list0) dgl.save_graphs(path, g_list0)
g_list, _ = load_graphs(path) g_list, _ = dgl.load_graphs(path)
assert g_list[0].idtype == F.int64 assert g_list[0].idtype == F.int64
assert len(g_list[0].canonical_etypes) == 3 assert len(g_list[0].canonical_etypes) == 3
for i in range(len(g_list0)): for i in range(len(g_list0)):
...@@ -302,9 +303,9 @@ def test_serialize_heterograph(): ...@@ -302,9 +303,9 @@ def test_serialize_heterograph():
def test_serialize_heterograph_s3(): def test_serialize_heterograph_s3():
path = "s3://dglci-data-test/graph2.bin" path = "s3://dglci-data-test/graph2.bin"
g_list0 = create_heterographs(F.int64) + create_heterographs(F.int32) g_list0 = create_heterographs(F.int64) + create_heterographs(F.int32)
save_graphs(path, g_list0) dgl.save_graphs(path, g_list0)
g_list = load_graphs(path, [0, 2, 5]) g_list = dgl.load_graphs(path, [0, 2, 5])
assert g_list[0].idtype == F.int64 assert g_list[0].idtype == F.int64
#assert g_list[1].restrict_format() == 'csr' #assert g_list[1].restrict_format() == 'csr'
assert np.allclose( assert np.allclose(
...@@ -327,8 +328,8 @@ if __name__ == "__main__": ...@@ -327,8 +328,8 @@ if __name__ == "__main__":
#test_graph_serialize_with_labels(False) #test_graph_serialize_with_labels(False)
#test_serialize_tensors() #test_serialize_tensors()
#test_serialize_empty_dict() #test_serialize_empty_dict()
#test_load_old_files1() # test_load_old_files1()
#test_load_old_files2() test_load_old_files2()
#test_serialize_heterograph() #test_serialize_heterograph()
#test_serialize_heterograph_s3() #test_serialize_heterograph_s3()
#test_deserialize_old_heterograph_file() #test_deserialize_old_heterograph_file()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment