Unverified Commit 18a26fcf authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

[Data] Support HeteroGraph save/load (#1526)



* 111

* add history version test

* fix

* 111

* save

* ``

* fix1

* 111

* add save heterograph

* lint

* lint

* add tests

* minor fix

* fix

* docs

* add format tets

* use unique_ptr

* fix

* fix interface

* 111

* 111

* fix

* lint

* fix

* add support to s3

* fix

* fix

* fix leak

* fix

* fix docs

* fix

* linlt

* fix

* fix

* fix

* address comment

* address comment
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
Co-authored-by: default avatarxiang song(charlie.song) <classicxsong@gmail.com>
parent dd8d5289
...@@ -1442,11 +1442,13 @@ bool UnitGraph::Load(dmlc::Stream* fs) { ...@@ -1442,11 +1442,13 @@ bool UnitGraph::Load(dmlc::Stream* fs) {
CHECK(fs->Read(&magicNum)) << "Invalid Magic Number"; CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
CHECK_EQ(magicNum, kDGLSerialize_UnitGraphMagic) << "Invalid UnitGraph Data"; CHECK_EQ(magicNum, kDGLSerialize_UnitGraphMagic) << "Invalid UnitGraph Data";
int64_t format_code; int64_t save_format_code, restrict_format_code;
CHECK(fs->Read(&format_code)) << "Invalid format"; CHECK(fs->Read(&save_format_code)) << "Invalid format";
restrict_format_ = static_cast<SparseFormat>(format_code); CHECK(fs->Read(&restrict_format_code)) << "Invalid format";
restrict_format_ = static_cast<SparseFormat>(restrict_format_code);
auto save_format = static_cast<SparseFormat>(save_format_code);
switch (restrict_format_) { switch (save_format) {
case SparseFormat::kCOO: case SparseFormat::kCOO:
fs->Read(&coo_); fs->Read(&coo_);
break; break;
...@@ -1473,6 +1475,7 @@ void UnitGraph::Save(dmlc::Stream* fs) const { ...@@ -1473,6 +1475,7 @@ void UnitGraph::Save(dmlc::Stream* fs) const {
// sparse matrix // sparse matrix
auto avail_fmt = SelectFormat(SparseFormat::kAny); auto avail_fmt = SelectFormat(SparseFormat::kAny);
fs->Write(static_cast<int64_t>(avail_fmt)); fs->Write(static_cast<int64_t>(avail_fmt));
fs->Write(static_cast<int64_t>(restrict_format_));
switch (avail_fmt) { switch (avail_fmt) {
case SparseFormat::kCOO: case SparseFormat::kCOO:
fs->Write(GetCOO()); fs->Write(GetCOO());
......
...@@ -4,6 +4,7 @@ import scipy as sp ...@@ -4,6 +4,7 @@ import scipy as sp
import time import time
import tempfile import tempfile
import os import os
import pytest
from dgl import DGLGraph from dgl import DGLGraph
import dgl import dgl
...@@ -13,30 +14,33 @@ from dgl.data.utils import save_graphs, load_graphs, load_labels, save_tensors, ...@@ -13,30 +14,33 @@ from dgl.data.utils import save_graphs, load_graphs, load_labels, save_tensors,
np.random.seed(44) np.random.seed(44)
def generate_rand_graph(n): def generate_rand_graph(n, is_hetero):
arr = (sp.sparse.random(n, n, density=0.1, arr = (sp.sparse.random(n, n, density=0.1,
format='coo') != 0).astype(np.int64) format='coo') != 0).astype(np.int64)
return DGLGraph(arr, readonly=True) if is_hetero:
return dgl.graph(arr)
else:
return DGLGraph(arr, readonly=True)
def construct_graph(n, readonly=True): def construct_graph(n, is_hetero):
g_list = [] g_list = []
for i in range(n): for i in range(n):
g = generate_rand_graph(30) g = generate_rand_graph(30, is_hetero)
g.edata['e1'] = F.randn((g.number_of_edges(), 32)) g.edata['e1'] = F.randn((g.number_of_edges(), 32))
g.edata['e2'] = F.ones((g.number_of_edges(), 32)) g.edata['e2'] = F.ones((g.number_of_edges(), 32))
g.ndata['n1'] = F.randn((g.number_of_nodes(), 64)) g.ndata['n1'] = F.randn((g.number_of_nodes(), 64))
g.readonly(i % 2 == 0)
g_list.append(g) g_list.append(g)
return g_list return g_list
def test_graph_serialize_with_feature(): @pytest.mark.parametrize('is_hetero', [True, False])
def test_graph_serialize_with_feature(is_hetero):
num_graphs = 100 num_graphs = 100
t0 = time.time() t0 = time.time()
g_list = construct_graph(num_graphs) g_list = construct_graph(num_graphs, is_hetero)
t1 = time.time() t1 = time.time()
...@@ -68,17 +72,13 @@ def test_graph_serialize_with_feature(): ...@@ -68,17 +72,13 @@ def test_graph_serialize_with_feature():
assert F.allclose(load_g.edata['e2'], g_list[idx].edata['e2']) assert F.allclose(load_g.edata['e2'], g_list[idx].edata['e2'])
assert F.allclose(load_g.ndata['n1'], g_list[idx].ndata['n1']) assert F.allclose(load_g.ndata['n1'], g_list[idx].ndata['n1'])
t4 = time.time()
bg = dgl.batch(loadg_list)
t5 = time.time()
print("Batch time: {} s".format(t5 - t4))
os.unlink(path) os.unlink(path)
def test_graph_serialize_without_feature(): @pytest.mark.parametrize('is_hetero', [True, False])
def test_graph_serialize_without_feature(is_hetero):
num_graphs = 100 num_graphs = 100
g_list = [generate_rand_graph(30) for _ in range(num_graphs)] g_list = [generate_rand_graph(30, is_hetero) for _ in range(num_graphs)]
# create a temporary file and immediately release it so DGL can open it. # create a temporary file and immediately release it so DGL can open it.
f = tempfile.NamedTemporaryFile(delete=False) f = tempfile.NamedTemporaryFile(delete=False)
...@@ -102,10 +102,10 @@ def test_graph_serialize_without_feature(): ...@@ -102,10 +102,10 @@ def test_graph_serialize_without_feature():
os.unlink(path) os.unlink(path)
@pytest.mark.parametrize('is_hetero', [True, False])
def test_graph_serialize_with_labels(): def test_graph_serialize_with_labels(is_hetero):
num_graphs = 100 num_graphs = 100
g_list = [generate_rand_graph(30) for _ in range(num_graphs)] g_list = [generate_rand_graph(30, is_hetero) for _ in range(num_graphs)]
labels = {"label": F.zeros((num_graphs, 1))} labels = {"label": F.zeros((num_graphs, 1))}
# create a temporary file and immediately release it so DGL can open it. # create a temporary file and immediately release it so DGL can open it.
...@@ -162,6 +162,7 @@ def test_serialize_tensors(): ...@@ -162,6 +162,7 @@ def test_serialize_tensors():
os.unlink(path) os.unlink(path)
def test_serialize_empty_dict(): def test_serialize_empty_dict():
# create a temporary file and immediately release it so DGL can open it. # create a temporary file and immediately release it so DGL can open it.
f = tempfile.NamedTemporaryFile(delete=False) f = tempfile.NamedTemporaryFile(delete=False)
...@@ -174,13 +175,135 @@ def test_serialize_empty_dict(): ...@@ -174,13 +175,135 @@ def test_serialize_empty_dict():
load_tensor_dict = load_tensors(path) load_tensor_dict = load_tensors(path)
assert isinstance(load_tensor_dict, dict) assert isinstance(load_tensor_dict, dict)
assert len(load_tensor_dict) == 0 assert len(load_tensor_dict) == 0
os.unlink(path) os.unlink(path)
def test_load_old_files1():
loadg_list, _ = load_graphs(os.path.join(
os.path.dirname(__file__), "data/1.bin"))
idx, num_nodes, edge0, edge1, edata_e1, edata_e2, ndata_n1 = np.load(
os.path.join(os.path.dirname(__file__), "data/1.npy"), allow_pickle=True)
load_g = loadg_list[idx]
load_edges = load_g.all_edges('uv', 'eid')
assert np.allclose(F.asnumpy(load_edges[0]), edge0)
assert np.allclose(F.asnumpy(load_edges[1]), edge1)
assert np.allclose(F.asnumpy(load_g.edata['e1']), edata_e1)
assert np.allclose(F.asnumpy(load_g.edata['e2']), edata_e2)
assert np.allclose(F.asnumpy(load_g.ndata['n1']), ndata_n1)
def test_load_old_files2():
loadg_list, labels0 = load_graphs(os.path.join(
os.path.dirname(__file__), "data/2.bin"))
labels1 = load_labels(os.path.join(
os.path.dirname(__file__), "data/2.bin"))
idx, edges0, edges1, np_labels = np.load(os.path.join(
os.path.dirname(__file__), "data/2.npy"), allow_pickle=True)
assert np.allclose(F.asnumpy(labels0['label']), np_labels)
assert np.allclose(F.asnumpy(labels1['label']), np_labels)
load_g = loadg_list[idx]
load_edges = load_g.all_edges('uv', 'eid')
assert np.allclose(F.asnumpy(load_edges[0]), edges0)
assert np.allclose(F.asnumpy(load_edges[1]), edges1)
def create_heterographs(index_dtype):
g_x = dgl.graph(([0, 1, 2], [1, 2, 3]), 'user',
'follows', index_dtype=index_dtype, restrict_format='any')
g_y = dgl.graph(([0, 2], [2, 3]), 'user', 'knows', index_dtype=index_dtype, restrict_format='csr')
g_x.nodes['user'].data['h'] = F.randn((4, 3))
g_x.edges['follows'].data['w'] = F.randn((3, 2))
g_y.nodes['user'].data['hh'] = F.ones((4, 5))
g_y.edges['knows'].data['ww'] = F.randn((2, 10))
g = dgl.hetero_from_relations([g_x, g_y])
return [g, g_x, g_y]
def test_deserialize_old_heterograph_file():
path = os.path.join(
os.path.dirname(__file__), "data/hetero1.bin")
g_list, label_dict = load_graphs(path)
assert g_list[0].idtype == F.int64
assert g_list[3].idtype == F.int32
assert np.allclose(
F.asnumpy(g_list[2].nodes['user'].data['hh']), np.ones((4, 5)))
assert np.allclose(
F.asnumpy(g_list[5].nodes['user'].data['hh']), np.ones((4, 5)))
edges = g_list[0]['follows'].edges()
assert np.allclose(F.asnumpy(edges[0]), np.array([0, 1, 2]))
assert np.allclose(F.asnumpy(edges[1]), np.array([1, 2, 3]))
assert F.allclose(label_dict['graph_label'], F.ones(54))
def create_old_heterograph_files():
path = os.path.join(
os.path.dirname(__file__), "data/hetero1.bin")
g_list0 = create_heterographs("int64") + create_heterographs("int32")
labels_dict = {"graph_label": F.ones(54)}
save_graphs(path, g_list0, labels_dict)
def test_serialize_heterograph():
f = tempfile.NamedTemporaryFile(delete=False)
path = f.name
f.close()
g_list0 = create_heterographs("int64") + create_heterographs("int32")
save_graphs(path, g_list0)
g_list, _ = load_graphs(path)
assert g_list[0].idtype == F.int64
assert g_list[1].restrict_format() == 'any'
assert g_list[2].restrict_format() == 'csr'
assert g_list[3].idtype == F.int32
assert np.allclose(
F.asnumpy(g_list[2].nodes['user'].data['hh']), np.ones((4, 5)))
assert np.allclose(
F.asnumpy(g_list[5].nodes['user'].data['hh']), np.ones((4, 5)))
edges = g_list[0]['follows'].edges()
assert np.allclose(F.asnumpy(edges[0]), np.array([0, 1, 2]))
assert np.allclose(F.asnumpy(edges[1]), np.array([1, 2, 3]))
for i in range(len(g_list)):
assert g_list[i].ntypes == g_list0[i].ntypes
assert g_list[i].etypes == g_list0[i].etypes
os.unlink(path)
@pytest.mark.skip(reason="lack of permission on CI")
def test_serialize_heterograph_s3():
path = "s3://dglci-data-test/graph2.bin"
g_list0 = create_heterographs("int64") + create_heterographs("int32")
save_graphs(path, g_list0)
g_list = load_graphs(path, [0, 2, 5])
assert g_list[0].idtype == F.int64
assert g_list[1].restrict_format() == 'csr'
assert np.allclose(
F.asnumpy(g_list[1].nodes['user'].data['hh']), np.ones((4, 5)))
assert np.allclose(
F.asnumpy(g_list[2].nodes['user'].data['hh']), np.ones((4, 5)))
edges = g_list[0]['follows'].edges()
assert np.allclose(F.asnumpy(edges[0]), np.array([0, 1, 2]))
assert np.allclose(F.asnumpy(edges[1]), np.array([1, 2, 3]))
if __name__ == "__main__": if __name__ == "__main__":
test_graph_serialize_with_feature() pass
test_graph_serialize_without_feature() test_graph_serialize_with_feature(True)
test_graph_serialize_with_labels() test_graph_serialize_with_feature(False)
test_graph_serialize_without_feature(True)
test_graph_serialize_without_feature(False)
test_graph_serialize_with_labels(True)
test_graph_serialize_with_labels(False)
test_serialize_tensors() test_serialize_tensors()
test_serialize_empty_dict() test_serialize_empty_dict()
\ No newline at end of file test_load_old_files1()
test_load_old_files2()
test_serialize_heterograph()
# test_serialize_heterograph_s3()
test_deserialize_old_heterograph_file()
# create_old_heterograph_files()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment