"tests/vscode:/vscode.git/clone" did not exist on "d19887cde5bdbdef724e96842b28287f1bc5fead"
Unverified Commit 13204383 authored by Andrei Ivanov's avatar Andrei Ivanov Committed by GitHub
Browse files

Improving basic tests. (#6145)

parent 44f4b0e2
import unittest import warnings
from collections import defaultdict as ddict from collections import defaultdict as ddict
import backend as F import backend as F
...@@ -6,8 +6,6 @@ import backend as F ...@@ -6,8 +6,6 @@ import backend as F
import dgl import dgl
import networkx as nx import networkx as nx
import numpy as np import numpy as np
import scipy.sparse as ssp
from dgl import DGLGraph
from utils import parametrize_idtype from utils import parametrize_idtype
D = 5 D = 5
...@@ -33,7 +31,7 @@ def apply_node_func(nodes): ...@@ -33,7 +31,7 @@ def apply_node_func(nodes):
def generate_graph_old(grad=False): def generate_graph_old(grad=False):
g = DGLGraph() g = dgl.graph([])
g.add_nodes(10) # 10 nodes g.add_nodes(10) # 10 nodes
# create a graph where 0 is the source and 9 is the sink # create a graph where 0 is the source and 9 is the sink
# 17 edges # 17 edges
...@@ -419,7 +417,14 @@ def test_update_all_0deg(idtype): ...@@ -419,7 +417,14 @@ def test_update_all_0deg(idtype):
# test#2: graph with no edge # test#2: graph with no edge
g = dgl.graph(([], []), num_nodes=5, idtype=idtype, device=F.ctx()) g = dgl.graph(([], []), num_nodes=5, idtype=idtype, device=F.ctx())
g.ndata["h"] = old_repr g.ndata["h"] = old_repr
g.update_all(_message, _reduce, lambda nodes: {"h": nodes.data["h"] * 2}) # Intercepting the warning: The input graph for the user-defined edge
# function does not contain valid edges.
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
g.update_all(
_message, _reduce, lambda nodes: {"h": nodes.data["h"] * 2}
)
new_repr = g.ndata["h"] new_repr = g.ndata["h"]
# should fallback to apply # should fallback to apply
assert F.allclose(new_repr, 2 * old_repr) assert F.allclose(new_repr, 2 * old_repr)
...@@ -455,7 +460,12 @@ def test_pull_0deg(idtype): ...@@ -455,7 +460,12 @@ def test_pull_0deg(idtype):
# test#2: pull only 0deg node # test#2: pull only 0deg node
old = F.randn((2, 5)) old = F.randn((2, 5))
g.ndata["h"] = old g.ndata["h"] = old
# Intercepting the warning: The input graph for the user-defined edge
# function does not contain valid edges
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
g.pull(0, _message, _reduce, lambda nodes: {"h": nodes.data["h"] * 2}) g.pull(0, _message, _reduce, lambda nodes: {"h": nodes.data["h"] * 2})
new = g.ndata["h"] new = g.ndata["h"]
# 0deg check: fallback to apply # 0deg check: fallback to apply
assert F.allclose(new[0], 2 * old[0]) assert F.allclose(new[0], 2 * old[0])
...@@ -467,8 +477,7 @@ def test_dynamic_addition(): ...@@ -467,8 +477,7 @@ def test_dynamic_addition():
N = 3 N = 3
D = 1 D = 1
g = DGLGraph() g = dgl.graph([]).to(F.ctx())
g = g.to(F.ctx())
# Test node addition # Test node addition
g.add_nodes(N) g.add_nodes(N)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment