Unverified Commit 18a26fcf authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

[Data] Support HeteroGraph save/load (#1526)



* 111

* add history version test

* fix

* 111

* save

* ``

* fix1

* 111

* add save heterograph

* lint

* lint

* add tests

* minor fix

* fix

* docs

* add format tests

* use unique_ptr

* fix

* fix interface

* 111

* 111

* fix

* lint

* fix

* add support to s3

* fix

* fix

* fix leak

* fix

* fix docs

* fix

* lint

* fix

* fix

* fix

* address comment

* address comment
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
Co-authored-by: xiang song(charlie.song) <classicxsong@gmail.com>
parent dd8d5289
......@@ -1442,11 +1442,13 @@ bool UnitGraph::Load(dmlc::Stream* fs) {
CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
CHECK_EQ(magicNum, kDGLSerialize_UnitGraphMagic) << "Invalid UnitGraph Data";
int64_t format_code;
CHECK(fs->Read(&format_code)) << "Invalid format";
restrict_format_ = static_cast<SparseFormat>(format_code);
int64_t save_format_code, restrict_format_code;
CHECK(fs->Read(&save_format_code)) << "Invalid format";
CHECK(fs->Read(&restrict_format_code)) << "Invalid format";
restrict_format_ = static_cast<SparseFormat>(restrict_format_code);
auto save_format = static_cast<SparseFormat>(save_format_code);
switch (restrict_format_) {
switch (save_format) {
case SparseFormat::kCOO:
fs->Read(&coo_);
break;
......@@ -1473,6 +1475,7 @@ void UnitGraph::Save(dmlc::Stream* fs) const {
// sparse matrix
auto avail_fmt = SelectFormat(SparseFormat::kAny);
fs->Write(static_cast<int64_t>(avail_fmt));
fs->Write(static_cast<int64_t>(restrict_format_));
switch (avail_fmt) {
case SparseFormat::kCOO:
fs->Write(GetCOO());
......
......@@ -4,6 +4,7 @@ import scipy as sp
import time
import tempfile
import os
import pytest
from dgl import DGLGraph
import dgl
......@@ -13,30 +14,33 @@ from dgl.data.utils import save_graphs, load_graphs, load_labels, save_tensors,
np.random.seed(44)
def generate_rand_graph(n):
def generate_rand_graph(n, is_hetero):
arr = (sp.sparse.random(n, n, density=0.1,
format='coo') != 0).astype(np.int64)
return DGLGraph(arr, readonly=True)
if is_hetero:
return dgl.graph(arr)
else:
return DGLGraph(arr, readonly=True)
def construct_graph(n, readonly=True):
def construct_graph(n, is_hetero):
g_list = []
for i in range(n):
g = generate_rand_graph(30)
g = generate_rand_graph(30, is_hetero)
g.edata['e1'] = F.randn((g.number_of_edges(), 32))
g.edata['e2'] = F.ones((g.number_of_edges(), 32))
g.ndata['n1'] = F.randn((g.number_of_nodes(), 64))
g.readonly(i % 2 == 0)
g_list.append(g)
return g_list
def test_graph_serialize_with_feature():
@pytest.mark.parametrize('is_hetero', [True, False])
def test_graph_serialize_with_feature(is_hetero):
num_graphs = 100
t0 = time.time()
g_list = construct_graph(num_graphs)
g_list = construct_graph(num_graphs, is_hetero)
t1 = time.time()
......@@ -68,17 +72,13 @@ def test_graph_serialize_with_feature():
assert F.allclose(load_g.edata['e2'], g_list[idx].edata['e2'])
assert F.allclose(load_g.ndata['n1'], g_list[idx].ndata['n1'])
t4 = time.time()
bg = dgl.batch(loadg_list)
t5 = time.time()
print("Batch time: {} s".format(t5 - t4))
os.unlink(path)
def test_graph_serialize_without_feature():
@pytest.mark.parametrize('is_hetero', [True, False])
def test_graph_serialize_without_feature(is_hetero):
num_graphs = 100
g_list = [generate_rand_graph(30) for _ in range(num_graphs)]
g_list = [generate_rand_graph(30, is_hetero) for _ in range(num_graphs)]
# create a temporary file and immediately release it so DGL can open it.
f = tempfile.NamedTemporaryFile(delete=False)
......@@ -102,10 +102,10 @@ def test_graph_serialize_without_feature():
os.unlink(path)
def test_graph_serialize_with_labels():
@pytest.mark.parametrize('is_hetero', [True, False])
def test_graph_serialize_with_labels(is_hetero):
num_graphs = 100
g_list = [generate_rand_graph(30) for _ in range(num_graphs)]
g_list = [generate_rand_graph(30, is_hetero) for _ in range(num_graphs)]
labels = {"label": F.zeros((num_graphs, 1))}
# create a temporary file and immediately release it so DGL can open it.
......@@ -162,6 +162,7 @@ def test_serialize_tensors():
os.unlink(path)
def test_serialize_empty_dict():
# create a temporary file and immediately release it so DGL can open it.
f = tempfile.NamedTemporaryFile(delete=False)
......@@ -174,13 +175,135 @@ def test_serialize_empty_dict():
load_tensor_dict = load_tensors(path)
assert isinstance(load_tensor_dict, dict)
assert len(load_tensor_dict) == 0
assert len(load_tensor_dict) == 0
os.unlink(path)
def test_load_old_files1():
    """Load the checked-in ``data/1.bin`` and compare one graph with ``data/1.npy``.

    The ``.npy`` record unpacks into seven values: the index of the graph to
    inspect, a node count (not used by this test), the two edge endpoint
    arrays, and the expected ``e1``/``e2`` edge features and ``n1`` node
    feature.
    """
    here = os.path.dirname(__file__)
    graphs, _ = load_graphs(os.path.join(here, "data/1.bin"))
    record = np.load(os.path.join(here, "data/1.npy"), allow_pickle=True)
    idx, _num_nodes, edge0, edge1, edata_e1, edata_e2, ndata_n1 = record

    graph = graphs[idx]
    endpoints = graph.all_edges('uv', 'eid')

    # Structure first, then features, each against the stored reference.
    assert np.allclose(F.asnumpy(endpoints[0]), edge0)
    assert np.allclose(F.asnumpy(endpoints[1]), edge1)
    assert np.allclose(F.asnumpy(graph.edata['e1']), edata_e1)
    assert np.allclose(F.asnumpy(graph.edata['e2']), edata_e2)
    assert np.allclose(F.asnumpy(graph.ndata['n1']), ndata_n1)
def test_load_old_files2():
    """Load the checked-in ``data/2.bin`` (graphs plus labels) and compare with ``data/2.npy``.

    Labels are read twice, once through ``load_graphs`` and once through
    ``load_labels``; both must match the stored reference labels. One graph,
    selected by the stored index, is then checked edge by edge.
    """
    here = os.path.dirname(__file__)
    bin_path = os.path.join(here, "data/2.bin")

    graphs, labels_from_graphs = load_graphs(bin_path)
    labels_only = load_labels(bin_path)
    record = np.load(os.path.join(here, "data/2.npy"), allow_pickle=True)
    idx, edges0, edges1, np_labels = record

    assert np.allclose(F.asnumpy(labels_from_graphs['label']), np_labels)
    assert np.allclose(F.asnumpy(labels_only['label']), np_labels)

    graph = graphs[idx]
    endpoints = graph.all_edges('uv', 'eid')
    assert np.allclose(F.asnumpy(endpoints[0]), edges0)
    assert np.allclose(F.asnumpy(endpoints[1]), edges1)
def create_heterographs(index_dtype):
    """Build the three small 'user' graphs used by the heterograph serialization tests.

    Parameters
    ----------
    index_dtype
        Forwarded unchanged as ``index_dtype`` to ``dgl.graph``; callers in
        this file pass ``"int64"`` or ``"int32"``.

    Returns
    -------
    list
        ``[g, g_x, g_y]`` where ``g_x`` is the 'follows' graph created with
        ``restrict_format='any'``, ``g_y`` is the 'knows' graph created with
        ``restrict_format='csr'``, and ``g`` is
        ``dgl.hetero_from_relations([g_x, g_y])``.
    """
    follows_graph = dgl.graph(
        ([0, 1, 2], [1, 2, 3]),
        'user',
        'follows',
        index_dtype=index_dtype,
        restrict_format='any',
    )
    knows_graph = dgl.graph(
        ([0, 2], [2, 3]),
        'user',
        'knows',
        index_dtype=index_dtype,
        restrict_format='csr',
    )

    # Features are attached before combining; the call order of the random
    # generators is kept as-is.
    follows_graph.nodes['user'].data['h'] = F.randn((4, 3))
    follows_graph.edges['follows'].data['w'] = F.randn((3, 2))
    knows_graph.nodes['user'].data['hh'] = F.ones((4, 5))
    knows_graph.edges['knows'].data['ww'] = F.randn((2, 10))

    combined = dgl.hetero_from_relations([follows_graph, knows_graph])
    return [combined, follows_graph, knows_graph]
def test_deserialize_old_heterograph_file():
    """Load the checked-in ``data/hetero1.bin`` and verify its graphs and labels.

    The expected content matches what ``create_old_heterograph_files`` writes:
    six graphs (three with int64 ids followed by three with int32 ids) and a
    ``graph_label`` tensor of 54 ones.
    """
    here = os.path.dirname(__file__)
    graphs, labels = load_graphs(os.path.join(here, "data/hetero1.bin"))

    assert graphs[0].idtype == F.int64
    assert graphs[3].idtype == F.int32

    # Positions 2 and 5 are the 'knows' graphs carrying the all-ones 'hh' feature.
    for position in (2, 5):
        hh = graphs[position].nodes['user'].data['hh']
        assert np.allclose(F.asnumpy(hh), np.ones((4, 5)))

    follows_edges = graphs[0]['follows'].edges()
    assert np.allclose(F.asnumpy(follows_edges[0]), np.array([0, 1, 2]))
    assert np.allclose(F.asnumpy(follows_edges[1]), np.array([1, 2, 3]))
    assert F.allclose(labels['graph_label'], F.ones(54))
def create_old_heterograph_files():
    """Write the ``data/hetero1.bin`` fixture next to this test module.

    Not a test: it saves the int64 and int32 graph triples from
    ``create_heterographs`` together with a ``graph_label`` tensor of 54 ones.
    """
    target = os.path.join(os.path.dirname(__file__), "data/hetero1.bin")
    graphs = create_heterographs("int64") + create_heterographs("int32")
    save_graphs(target, graphs, {"graph_label": F.ones(54)})
def test_serialize_heterograph():
    """Round-trip six heterographs through ``save_graphs``/``load_graphs``.

    Checks that id dtype, ``restrict_format``, node features, edge structure
    and node/edge type names survive serialization. The temporary file is
    removed even when saving, loading or an assertion fails.
    """
    # Create a temporary file and immediately release it so DGL can open it.
    f = tempfile.NamedTemporaryFile(delete=False)
    path = f.name
    f.close()
    try:
        g_list0 = create_heterographs("int64") + create_heterographs("int32")
        save_graphs(path, g_list0)

        g_list, _ = load_graphs(path)
        # Guard the pairwise comparison below against a short or long result.
        assert len(g_list) == len(g_list0)

        assert g_list[0].idtype == F.int64
        assert g_list[1].restrict_format() == 'any'
        assert g_list[2].restrict_format() == 'csr'
        assert g_list[3].idtype == F.int32
        assert np.allclose(
            F.asnumpy(g_list[2].nodes['user'].data['hh']), np.ones((4, 5)))
        assert np.allclose(
            F.asnumpy(g_list[5].nodes['user'].data['hh']), np.ones((4, 5)))

        edges = g_list[0]['follows'].edges()
        assert np.allclose(F.asnumpy(edges[0]), np.array([0, 1, 2]))
        assert np.allclose(F.asnumpy(edges[1]), np.array([1, 2, 3]))

        for loaded, original in zip(g_list, g_list0):
            assert loaded.ntypes == original.ntypes
            assert loaded.etypes == original.etypes
    finally:
        # delete=False means nothing else will clean this file up.
        os.unlink(path)
@pytest.mark.skip(reason="lack of permission on CI")
def test_serialize_heterograph_s3():
    """Round-trip heterographs through an S3 path, loading only a subset.

    ``load_graphs`` is called with the index list ``[0, 2, 5]``; the
    assertions below presumably rely on it returning exactly those three
    graphs in that order (TODO confirm against ``load_graphs`` docs).
    """
    path = "s3://dglci-data-test/graph2.bin"
    g_list0 = create_heterographs("int64") + create_heterographs("int32")
    save_graphs(path, g_list0)

    # load_graphs returns a (graphs, labels) pair, as at every other call site
    # in this module; unpack it instead of indexing into the pair itself.
    g_list, _ = load_graphs(path, [0, 2, 5])

    assert g_list[0].idtype == F.int64
    assert g_list[1].restrict_format() == 'csr'
    assert np.allclose(
        F.asnumpy(g_list[1].nodes['user'].data['hh']), np.ones((4, 5)))
    assert np.allclose(
        F.asnumpy(g_list[2].nodes['user'].data['hh']), np.ones((4, 5)))

    edges = g_list[0]['follows'].edges()
    assert np.allclose(F.asnumpy(edges[0]), np.array([0, 1, 2]))
    assert np.allclose(F.asnumpy(edges[1]), np.array([1, 2, 3]))
if __name__ == "__main__":
test_graph_serialize_with_feature()
test_graph_serialize_without_feature()
test_graph_serialize_with_labels()
pass
test_graph_serialize_with_feature(True)
test_graph_serialize_with_feature(False)
test_graph_serialize_without_feature(True)
test_graph_serialize_without_feature(False)
test_graph_serialize_with_labels(True)
test_graph_serialize_with_labels(False)
test_serialize_tensors()
test_serialize_empty_dict()
\ No newline at end of file
test_serialize_empty_dict()
test_load_old_files1()
test_load_old_files2()
test_serialize_heterograph()
# test_serialize_heterograph_s3()
test_deserialize_old_heterograph_file()
# create_old_heterograph_files()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment