Unverified Commit 8a830272 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] Merge test utils. (#5440)



* merge

* format

* rename

* sort

* sort

* update

* update

* update

* Update tests/utils/checks.py
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>

---------
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
parent 5421940a
...@@ -4,8 +4,7 @@ import backend as F ...@@ -4,8 +4,7 @@ import backend as F
import dgl import dgl
import networkx as nx import networkx as nx
import utils as U from utils import check_fail, parametrize_idtype
from pytests_utils import parametrize_idtype
def create_graph(idtype): def create_graph(idtype):
...@@ -87,7 +86,7 @@ def test_prop_edges_dfs(idtype): ...@@ -87,7 +86,7 @@ def test_prop_edges_dfs(idtype):
def test_prop_nodes_topo(idtype): def test_prop_nodes_topo(idtype):
# bi-directional chain # bi-directional chain
g = create_graph(idtype) g = create_graph(idtype)
assert U.check_fail(dgl.prop_nodes_topo, g) # has loop assert check_fail(dgl.prop_nodes_topo, g) # has loop
# tree # tree
tree = dgl.DGLGraph() tree = dgl.DGLGraph()
......
...@@ -6,8 +6,8 @@ import dgl ...@@ -6,8 +6,8 @@ import dgl
import networkx as nx import networkx as nx
import numpy as np import numpy as np
import pytest import pytest
from pytests_utils import parametrize_idtype from utils import parametrize_idtype
from pytests_utils.graph_cases import get_cases from utils.graph_cases import get_cases
@parametrize_idtype @parametrize_idtype
......
...@@ -4,7 +4,7 @@ import dgl ...@@ -4,7 +4,7 @@ import dgl
import numpy as np import numpy as np
import pytest import pytest
import scipy.sparse as ssp import scipy.sparse as ssp
from pytests_utils import parametrize_idtype from utils import parametrize_idtype
if F.backend_name == "pytorch": if F.backend_name == "pytorch":
import torch import torch
......
...@@ -7,7 +7,7 @@ import networkx as nx ...@@ -7,7 +7,7 @@ import networkx as nx
import numpy as np import numpy as np
import pytest import pytest
import scipy.sparse as ssp import scipy.sparse as ssp
from pytests_utils import parametrize_idtype from utils import parametrize_idtype
D = 5 D = 5
......
...@@ -10,7 +10,7 @@ import dgl ...@@ -10,7 +10,7 @@ import dgl
import networkx as nx import networkx as nx
import numpy as np import numpy as np
import scipy.sparse as sp import scipy.sparse as sp
from pytests_utils import parametrize_idtype from utils import parametrize_idtype
np.random.seed(42) np.random.seed(42)
......
...@@ -11,7 +11,7 @@ import numpy as np ...@@ -11,7 +11,7 @@ import numpy as np
import pytest import pytest
import scipy.sparse as ssp import scipy.sparse as ssp
from dgl import DGLError from dgl import DGLError
from pytests_utils import parametrize_idtype from utils import parametrize_idtype
def create_test_heterograph(num_nodes, num_adj, idtype): def create_test_heterograph(num_nodes, num_adj, idtype):
......
...@@ -18,7 +18,7 @@ import backend as F ...@@ -18,7 +18,7 @@ import backend as F
import dgl import dgl
import dgl.partition import dgl.partition
from pytests_utils import parametrize_idtype from utils import parametrize_idtype
@parametrize_idtype @parametrize_idtype
......
...@@ -26,9 +26,9 @@ import dgl.partition ...@@ -26,9 +26,9 @@ import dgl.partition
import networkx as nx import networkx as nx
import numpy as np import numpy as np
import pytest import pytest
from pytests_utils import parametrize_idtype
from pytests_utils.graph_cases import get_cases
from scipy import sparse as spsp from scipy import sparse as spsp
from utils import parametrize_idtype
from utils.graph_cases import get_cases
D = 5 D = 5
......
...@@ -5,7 +5,7 @@ import backend as F ...@@ -5,7 +5,7 @@ import backend as F
import dgl import dgl
import numpy as np import numpy as np
from dgl.utils import Filter from dgl.utils import Filter
from pytests_utils import parametrize_idtype from utils import parametrize_idtype
def test_graph_filter(): def test_graph_filter():
......
...@@ -9,8 +9,8 @@ import numpy as np ...@@ -9,8 +9,8 @@ import numpy as np
import pytest import pytest
import scipy as sp import scipy as sp
from mxnet import autograd, gluon, nd from mxnet import autograd, gluon, nd
from pytests_utils import parametrize_idtype from utils import parametrize_idtype
from pytests_utils.graph_cases import ( from utils.graph_cases import (
get_cases, get_cases,
random_bipartite, random_bipartite,
random_dglgraph, random_dglgraph,
......
...@@ -11,7 +11,7 @@ import numpy as np ...@@ -11,7 +11,7 @@ import numpy as np
import pytest import pytest
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from pytests_utils import parametrize_idtype from utils import parametrize_idtype
@pytest.mark.parametrize("batch_size", [None, 16]) @pytest.mark.parametrize("batch_size", [None, 16])
......
...@@ -8,8 +8,8 @@ import torch as th ...@@ -8,8 +8,8 @@ import torch as th
from dgl import DGLError from dgl import DGLError
from dgl.base import DGLWarning from dgl.base import DGLWarning
from dgl.geometry import farthest_point_sampler, neighbor_matching from dgl.geometry import farthest_point_sampler, neighbor_matching
from pytests_utils import parametrize_idtype from utils import parametrize_idtype
from pytests_utils.graph_cases import get_cases from utils.graph_cases import get_cases
def test_fps(): def test_fps():
......
...@@ -6,7 +6,7 @@ import dgl ...@@ -6,7 +6,7 @@ import dgl
import numpy as np import numpy as np
import pytest import pytest
import torch import torch
from pytests_utils import parametrize_idtype from utils import parametrize_idtype
random.seed(42) random.seed(42)
np.random.seed(42) np.random.seed(42)
......
...@@ -12,15 +12,15 @@ import pytest ...@@ -12,15 +12,15 @@ import pytest
import scipy as sp import scipy as sp
import torch import torch
import torch as th import torch as th
from pytests_utils import parametrize_idtype from torch.optim import Adam, SparseAdam
from pytests_utils.graph_cases import ( from torch.utils.data import DataLoader
from utils import parametrize_idtype
from utils.graph_cases import (
get_cases, get_cases,
random_bipartite, random_bipartite,
random_dglgraph, random_dglgraph,
random_graph, random_graph,
) )
from torch.optim import Adam, SparseAdam
from torch.utils.data import DataLoader
tmp_buffer = io.BytesIO() tmp_buffer = io.BytesIO()
......
...@@ -10,14 +10,14 @@ import numpy as np ...@@ -10,14 +10,14 @@ import numpy as np
import pytest import pytest
import scipy as sp import scipy as sp
import tensorflow as tf import tensorflow as tf
from pytests_utils import parametrize_idtype from tensorflow.keras import layers
from pytests_utils.graph_cases import ( from utils import parametrize_idtype
from utils.graph_cases import (
get_cases, get_cases,
random_bipartite, random_bipartite,
random_dglgraph, random_dglgraph,
random_graph, random_graph,
) )
from tensorflow.keras import layers
def _AXWb(A, X, W, b): def _AXWb(A, X, W, b):
......
...@@ -4,6 +4,13 @@ import dgl ...@@ -4,6 +4,13 @@ import dgl
import pytest import pytest
from dgl.base import is_internal_column from dgl.base import is_internal_column
__all__ = [
"check_fail",
"assert_is_identical",
"assert_is_identical_hetero",
"check_graph_equal",
]
def check_fail(fn, *args, **kwargs): def check_fail(fn, *args, **kwargs):
try: try:
...@@ -66,3 +73,52 @@ def assert_is_identical_hetero(g, g2, ignore_internal_data=False): ...@@ -66,3 +73,52 @@ def assert_is_identical_hetero(g, g2, ignore_internal_data=False):
assert len(g.edges[etype].data) == len(g2.edges[etype].data) assert len(g.edges[etype].data) == len(g2.edges[etype].data)
for k in g.edges[etype].data: for k in g.edges[etype].data:
assert F.allclose(g.edges[etype].data[k], g2.edges[etype].data[k]) assert F.allclose(g.edges[etype].data[k], g2.edges[etype].data[k])
def check_graph_equal(g1, g2, *, check_idtype=True, check_feature=True):
assert g1.device == g2.device
if check_idtype:
assert g1.idtype == g2.idtype
assert g1.ntypes == g2.ntypes
assert g1.etypes == g2.etypes
assert g1.srctypes == g2.srctypes
assert g1.dsttypes == g2.dsttypes
assert g1.canonical_etypes == g2.canonical_etypes
assert g1.batch_size == g2.batch_size
# check if two metagraphs are identical
for edges, features in g1.metagraph().edges(keys=True).items():
assert g2.metagraph().edges(keys=True)[edges] == features
for nty in g1.ntypes:
assert g1.number_of_nodes(nty) == g2.number_of_nodes(nty)
assert F.allclose(g1.batch_num_nodes(nty), g2.batch_num_nodes(nty))
for ety in g1.canonical_etypes:
assert g1.number_of_edges(ety) == g2.number_of_edges(ety)
assert F.allclose(g1.batch_num_edges(ety), g2.batch_num_edges(ety))
src1, dst1, eid1 = g1.edges(etype=ety, form="all")
src2, dst2, eid2 = g2.edges(etype=ety, form="all")
if check_idtype:
assert F.allclose(src1, src2)
assert F.allclose(dst1, dst2)
assert F.allclose(eid1, eid2)
else:
assert F.allclose(src1, F.astype(src2, g1.idtype))
assert F.allclose(dst1, F.astype(dst2, g1.idtype))
assert F.allclose(eid1, F.astype(eid2, g1.idtype))
if check_feature:
for nty in g1.ntypes:
if g1.number_of_nodes(nty) == 0:
continue
for feat_name in g1.nodes[nty].data.keys():
assert F.allclose(
g1.nodes[nty].data[feat_name], g2.nodes[nty].data[feat_name]
)
for ety in g1.canonical_etypes:
if g1.number_of_edges(ety) == 0:
continue
for feat_name in g2.edges[ety].data.keys():
assert F.allclose(
g1.edges[ety].data[feat_name], g2.edges[ety].data[feat_name]
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment