Unverified Commit 77c84834 authored by Andrei Ivanov's avatar Andrei Ivanov Committed by GitHub
Browse files

Improving serialize tests (#6116)

parent d19887cd
...@@ -2,6 +2,7 @@ import os ...@@ -2,6 +2,7 @@ import os
import tempfile import tempfile
import time import time
import unittest import unittest
import warnings
import backend as F import backend as F
...@@ -10,26 +11,22 @@ import dgl.ndarray as nd ...@@ -10,26 +11,22 @@ import dgl.ndarray as nd
import numpy as np import numpy as np
import pytest import pytest
import scipy as sp import scipy as sp
from dgl import DGLGraph
from dgl.data.utils import load_labels, load_tensors, save_tensors from dgl.data.utils import load_labels, load_tensors, save_tensors
np.random.seed(44) np.random.seed(44)
def generate_rand_graph(n, is_hetero): def generate_rand_graph(n):
arr = (sp.sparse.random(n, n, density=0.1, format="coo") != 0).astype( arr = (sp.sparse.random(n, n, density=0.1, format="coo") != 0).astype(
np.int64 np.int64
) )
if is_hetero: return dgl.from_scipy(arr)
return dgl.from_scipy(arr)
else:
return DGLGraph(arr, readonly=True)
def construct_graph(n, is_hetero): def construct_graph(n):
g_list = [] g_list = []
for i in range(n): for _ in range(n):
g = generate_rand_graph(30, is_hetero) g = generate_rand_graph(30)
g.edata["e1"] = F.randn((g.num_edges(), 32)) g.edata["e1"] = F.randn((g.num_edges(), 32))
g.edata["e2"] = F.ones((g.num_edges(), 32)) g.edata["e2"] = F.ones((g.num_edges(), 32))
g.ndata["n1"] = F.randn((g.num_nodes(), 64)) g.ndata["n1"] = F.randn((g.num_nodes(), 64))
...@@ -38,13 +35,12 @@ def construct_graph(n, is_hetero): ...@@ -38,13 +35,12 @@ def construct_graph(n, is_hetero):
@unittest.skipIf(F._default_context_str == "gpu", reason="GPU not implemented") @unittest.skipIf(F._default_context_str == "gpu", reason="GPU not implemented")
@pytest.mark.parametrize("is_hetero", [True, False]) def test_graph_serialize_with_feature():
def test_graph_serialize_with_feature(is_hetero):
num_graphs = 100 num_graphs = 100
t0 = time.time() t0 = time.time()
g_list = construct_graph(num_graphs, is_hetero) g_list = construct_graph(num_graphs)
t1 = time.time() t1 = time.time()
...@@ -80,10 +76,9 @@ def test_graph_serialize_with_feature(is_hetero): ...@@ -80,10 +76,9 @@ def test_graph_serialize_with_feature(is_hetero):
@unittest.skipIf(F._default_context_str == "gpu", reason="GPU not implemented") @unittest.skipIf(F._default_context_str == "gpu", reason="GPU not implemented")
@pytest.mark.parametrize("is_hetero", [True, False]) def test_graph_serialize_without_feature():
def test_graph_serialize_without_feature(is_hetero):
num_graphs = 100 num_graphs = 100
g_list = [generate_rand_graph(30, is_hetero) for _ in range(num_graphs)] g_list = [generate_rand_graph(30) for _ in range(num_graphs)]
# create a temporary file and immediately release it so DGL can open it. # create a temporary file and immediately release it so DGL can open it.
f = tempfile.NamedTemporaryFile(delete=False) f = tempfile.NamedTemporaryFile(delete=False)
...@@ -109,10 +104,9 @@ def test_graph_serialize_without_feature(is_hetero): ...@@ -109,10 +104,9 @@ def test_graph_serialize_without_feature(is_hetero):
@unittest.skipIf(F._default_context_str == "gpu", reason="GPU not implemented") @unittest.skipIf(F._default_context_str == "gpu", reason="GPU not implemented")
@pytest.mark.parametrize("is_hetero", [True, False]) def test_graph_serialize_with_labels():
def test_graph_serialize_with_labels(is_hetero):
num_graphs = 100 num_graphs = 100
g_list = [generate_rand_graph(30, is_hetero) for _ in range(num_graphs)] g_list = [generate_rand_graph(30) for _ in range(num_graphs)]
labels = {"label": F.zeros((num_graphs, 1))} labels = {"label": F.zeros((num_graphs, 1))}
# create a temporary file and immediately release it so DGL can open it. # create a temporary file and immediately release it so DGL can open it.
...@@ -191,10 +185,14 @@ def test_serialize_empty_dict(): ...@@ -191,10 +185,14 @@ def test_serialize_empty_dict():
os.unlink(path) os.unlink(path)
def load_old_files(files):
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
return dgl.load_graphs(os.path.join(os.path.dirname(__file__), files))
def test_load_old_files1(): def test_load_old_files1():
loadg_list, _ = dgl.load_graphs( loadg_list, _ = load_old_files("data/1.bin")
os.path.join(os.path.dirname(__file__), "data/1.bin")
)
idx, num_nodes, edge0, edge1, edata_e1, edata_e2, ndata_n1 = np.load( idx, num_nodes, edge0, edge1, edata_e1, edata_e2, ndata_n1 = np.load(
os.path.join(os.path.dirname(__file__), "data/1.npy"), allow_pickle=True os.path.join(os.path.dirname(__file__), "data/1.npy"), allow_pickle=True
) )
...@@ -210,9 +208,7 @@ def test_load_old_files1(): ...@@ -210,9 +208,7 @@ def test_load_old_files1():
def test_load_old_files2(): def test_load_old_files2():
loadg_list, labels0 = dgl.load_graphs( loadg_list, labels0 = load_old_files("data/2.bin")
os.path.join(os.path.dirname(__file__), "data/2.bin")
)
labels1 = load_labels(os.path.join(os.path.dirname(__file__), "data/2.bin")) labels1 = load_labels(os.path.join(os.path.dirname(__file__), "data/2.bin"))
idx, edges0, edges1, np_labels = np.load( idx, edges0, edges1, np_labels = np.load(
os.path.join(os.path.dirname(__file__), "data/2.npy"), allow_pickle=True os.path.join(os.path.dirname(__file__), "data/2.npy"), allow_pickle=True
...@@ -365,7 +361,6 @@ def test_serialize_heterograph_s3(): ...@@ -365,7 +361,6 @@ def test_serialize_heterograph_s3():
@unittest.skipIf(F._default_context_str == "gpu", reason="GPU not implemented") @unittest.skipIf(F._default_context_str == "gpu", reason="GPU not implemented")
@pytest.mark.parametrize("is_hetero", [True, False])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"formats", "formats",
[ [
...@@ -378,9 +373,9 @@ def test_serialize_heterograph_s3(): ...@@ -378,9 +373,9 @@ def test_serialize_heterograph_s3():
["coo", "csr", "csc"], ["coo", "csr", "csc"],
], ],
) )
def test_graph_serialize_with_formats(is_hetero, formats): def test_graph_serialize_with_formats(formats):
num_graphs = 100 num_graphs = 100
g_list = [generate_rand_graph(30, is_hetero) for _ in range(num_graphs)] g_list = [generate_rand_graph(30) for _ in range(num_graphs)]
# create a temporary file and immediately release it so DGL can open it. # create a temporary file and immediately release it so DGL can open it.
f = tempfile.NamedTemporaryFile(delete=False) f = tempfile.NamedTemporaryFile(delete=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment