"src/vscode:/vscode.git/clone" did not exist on "673eb60f1c4d971e1a577bed767053e50578b461"
Unverified Commit 8a830272 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] Merge test utils. (#5440)



* merge

* format

* rename

* sort

* sort

* update

* update

* update

* Update tests/utils/checks.py
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>

---------
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
parent 5421940a
import dgl
import pytest
import torch
from pytests_utils.graph_cases import get_cases
from utils.graph_cases import get_cases
from dglgo.model import *
......
import backend as F
import dgl
__all__ = ["check_graph_equal"]
def check_graph_equal(g1, g2, *, check_idtype=True, check_feature=True):
assert g1.device == g1.device
if check_idtype:
assert g1.idtype == g2.idtype
assert g1.ntypes == g2.ntypes
assert g1.etypes == g2.etypes
assert g1.srctypes == g2.srctypes
assert g1.dsttypes == g2.dsttypes
assert g1.canonical_etypes == g2.canonical_etypes
assert g1.batch_size == g2.batch_size
# check if two metagraphs are identical
for edges, features in g1.metagraph().edges(keys=True).items():
assert g2.metagraph().edges(keys=True)[edges] == features
for nty in g1.ntypes:
assert g1.number_of_nodes(nty) == g2.number_of_nodes(nty)
assert F.allclose(g1.batch_num_nodes(nty), g2.batch_num_nodes(nty))
for ety in g1.canonical_etypes:
assert g1.number_of_edges(ety) == g2.number_of_edges(ety)
assert F.allclose(g1.batch_num_edges(ety), g2.batch_num_edges(ety))
src1, dst1, eid1 = g1.edges(etype=ety, form="all")
src2, dst2, eid2 = g2.edges(etype=ety, form="all")
if check_idtype:
assert F.allclose(src1, src2)
assert F.allclose(dst1, dst2)
assert F.allclose(eid1, eid2)
else:
assert F.allclose(src1, F.astype(src2, g1.idtype))
assert F.allclose(dst1, F.astype(dst2, g1.idtype))
assert F.allclose(eid1, F.astype(eid2, g1.idtype))
if check_feature:
for nty in g1.ntypes:
if g1.number_of_nodes(nty) == 0:
continue
for feat_name in g1.nodes[nty].data.keys():
assert F.allclose(
g1.nodes[nty].data[feat_name], g2.nodes[nty].data[feat_name]
)
for ety in g1.canonical_etypes:
if g1.number_of_edges(ety) == 0:
continue
for feat_name in g2.edges[ety].data.keys():
assert F.allclose(
g1.edges[ety].data[feat_name], g2.edges[ety].data[feat_name]
)
......@@ -8,7 +8,7 @@ from dgl.dataloading import (
negative_sampler,
NeighborSampler,
)
from pytests_utils import parametrize_idtype
from utils import parametrize_idtype
def create_test_graph(idtype):
......
......@@ -8,7 +8,7 @@ import networkx as nx
import numpy as np
import scipy.sparse as ssp
from dgl import DGLGraph
from pytests_utils import parametrize_idtype
from utils import parametrize_idtype
D = 5
reduce_msg_shapes = set()
......
......@@ -10,12 +10,11 @@ import dgl.function as fn
import networkx as nx
import numpy as np
import pytest
import pytests_utils
import scipy.sparse as ssp
from dgl import DGLError
from dgl.ops import edge_softmax
from pytests_utils import get_cases, parametrize_idtype
from scipy.sparse import rand
from utils import get_cases, parametrize_idtype
edge_softmax_shapes = [(1,), (1, 3), (3, 4, 5)]
rfuncs = {"sum": fn.sum, "max": fn.max, "min": fn.min, "mean": fn.mean}
......
......@@ -8,8 +8,8 @@ import numpy as np
import pytest
import torch
from dgl.ops import gather_mm, gsddmm, gspmm, segment_reduce
from pytests_utils import parametrize_idtype
from pytests_utils.graph_cases import get_cases
from utils import parametrize_idtype
from utils.graph_cases import get_cases
random.seed(42)
np.random.seed(42)
......
......@@ -4,7 +4,7 @@ import backend as F
import dgl
import numpy as np
from pytests_utils import parametrize_idtype
from utils import parametrize_idtype
def tree1(idtype):
......
......@@ -5,7 +5,7 @@ import backend as F
import dgl
import pytest
from dgl.base import ALL
from pytests_utils import check_graph_equal, get_cases, parametrize_idtype
from utils import check_graph_equal, get_cases, parametrize_idtype
def check_equivalence_between_heterographs(
......
......@@ -7,7 +7,7 @@ import dgl
import dgl.ndarray as nd
import numpy as np
from dgl.frame import Column
from pytests_utils import parametrize_idtype
from utils import parametrize_idtype
def test_column_subcolumn():
......
......@@ -10,11 +10,10 @@ import dgl.function as fn
import networkx as nx
import numpy as np
import pytest
import pytests_utils
import scipy.sparse as ssp
from dgl import DGLError
from pytests_utils import get_cases, parametrize_idtype
from scipy.sparse import rand
from utils import get_cases, parametrize_idtype
rfuncs = {"sum": fn.sum, "max": fn.max, "min": fn.min, "mean": fn.mean}
fill_value = {"sum": 0, "max": float("-inf")}
......
......@@ -7,7 +7,7 @@ import dgl.function as fn
import networkx as nx
import numpy as np
import pytest
from pytests_utils import get_cases, parametrize_idtype
from utils import get_cases, parametrize_idtype
def udf_copy_src(edges):
......
......@@ -8,12 +8,16 @@ import dgl
import dgl.function as fn
import networkx as nx
import pytest
import pytests_utils
import scipy.sparse as ssp
from dgl.graph_index import create_graph_index
from dgl.utils import toindex
from pytests_utils import get_cases, parametrize_idtype
from utils import assert_is_identical, assert_is_identical_hetero
from utils import (
assert_is_identical,
assert_is_identical_hetero,
check_graph_equal,
get_cases,
parametrize_idtype,
)
def _assert_is_identical_nodeflow(nf1, nf2):
......@@ -110,7 +114,7 @@ def _global_message_func(nodes):
def test_pickling_graph(g, idtype):
g = g.astype(idtype)
new_g = _reconstruct_pickle(g)
pytests_utils.check_graph_equal(g, new_g, check_feature=True)
check_graph_equal(g, new_g, check_feature=True)
@unittest.skipIf(F._default_context_str == "gpu", reason="GPU not implemented")
......@@ -142,7 +146,7 @@ def test_pickling_batched_heterograph():
bg = dgl.batch([g, g2])
new_bg = _reconstruct_pickle(bg)
pytests_utils.check_graph_equal(bg, new_bg)
check_graph_equal(bg, new_bg)
@unittest.skipIf(
......
......@@ -2,7 +2,7 @@ import backend as F
import dgl
import numpy as np
from pytests_utils import parametrize_idtype
from utils import parametrize_idtype
@parametrize_idtype
......
......@@ -12,7 +12,7 @@ import networkx as nx
import scipy.sparse as ssp
from dgl.graph_index import create_graph_index
from dgl.utils import toindex
from pytests_utils import parametrize_idtype
from utils import parametrize_idtype
def create_test_graph(idtype):
......
......@@ -4,7 +4,7 @@ import dgl
import dgl.function as fn
import numpy as np
import scipy.sparse as sp
from pytests_utils import parametrize_idtype
from utils import parametrize_idtype
D = 5
......
......@@ -10,11 +10,10 @@ import dgl.function as fn
import networkx as nx
import numpy as np
import pytest
import pytests_utils
import scipy.sparse as ssp
from dgl import DGLError
from pytests_utils import get_cases, parametrize_idtype
from scipy.sparse import rand
from utils import get_cases, parametrize_idtype
rfuncs = {"sum": fn.sum, "max": fn.max, "min": fn.min, "mean": fn.mean}
feat_size = 2
......
......@@ -10,12 +10,15 @@ import dgl.function as fn
import networkx as nx
import numpy as np
import pytest
import pytests_utils
import scipy.sparse as ssp
from dgl import DGLError
from pytests_utils import get_cases, parametrize_idtype
from scipy.sparse import rand
from utils import assert_is_identical_hetero
from utils import (
assert_is_identical_hetero,
check_graph_equal,
get_cases,
parametrize_idtype,
)
def create_test_heterograph(idtype):
......@@ -2419,7 +2422,7 @@ def test_dtype_cast(idtype):
else:
g_cast = g.int()
assert g_cast.idtype == F.int32
pytests_utils.check_graph_equal(g, g_cast, check_idtype=False)
check_graph_equal(g, g_cast, check_idtype=False)
def test_float_cast():
......
......@@ -4,7 +4,7 @@ import unittest
import backend as F
import dgl
from pytests_utils import parametrize_idtype
from utils import parametrize_idtype
@unittest.skipIf(
......
import backend as F
import dgl
from pytests_utils import parametrize_idtype
from utils import parametrize_idtype
@parametrize_idtype
......
......@@ -4,7 +4,7 @@ import backend as F
from dgl.distributed import graph_partition_book as gpb
from dgl.partition import NDArrayPartition
from pytests_utils import parametrize_idtype
from utils import parametrize_idtype
@unittest.skipIf(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment