Unverified Commit c49582c9 authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

[Test] Fix the random seed in test_kernel.py and fix the _print_error (#1194)



* Fix the random seed and fix the _print_error

* Fix check_positive_edge_sampler sample weight.
Co-authored-by: default avatarZihao Ye <zihaoye.cs@gmail.com>
Co-authored-by: default avatarVoVAllen <VoVAllen@users.noreply.github.com>
parent 64f6f3c1
...@@ -5,8 +5,6 @@ import numpy as np ...@@ -5,8 +5,6 @@ import numpy as np
import backend as F import backend as F
from itertools import product from itertools import product
np.random.seed(31)
def udf_copy_src(edges): def udf_copy_src(edges):
return {'m': edges.src['u']} return {'m': edges.src['u']}
...@@ -36,6 +34,7 @@ def generate_feature(g, broadcast='none', binary_op='none'): ...@@ -36,6 +34,7 @@ def generate_feature(g, broadcast='none', binary_op='none'):
"""Create graph with src, edge, dst feature. broadcast can be 'u', """Create graph with src, edge, dst feature. broadcast can be 'u',
'e', 'v', 'none' 'e', 'v', 'none'
""" """
np.random.seed(31)
nv = g.number_of_nodes() nv = g.number_of_nodes()
ne = g.number_of_edges() ne = g.number_of_edges()
if binary_op == 'dot': if binary_op == 'dot':
...@@ -303,8 +302,21 @@ def test_all_binary_builtins(): ...@@ -303,8 +302,21 @@ def test_all_binary_builtins():
def _print_error(a, b): def _print_error(a, b):
print("ERROR: Test {}_{}_{}_{} broadcast: {} partial: {}". print("ERROR: Test {}_{}_{}_{} broadcast: {} partial: {}".
format(lhs, binary_op, rhs, reducer, broadcast, partial)) format(lhs, binary_op, rhs, reducer, broadcast, partial))
print("lhs", lhs) if lhs == 'u':
print("rhs", rhs) lhs_data = hu
elif lhs == 'v':
lhs_data = hv
elif lhs == 'e':
lhs_data = he
if rhs == 'u':
rhs_data = hu
elif rhs == 'v':
rhs_data = hv
elif rhs == 'e':
rhs_data = he
print("lhs", F.asnumpy(lhs_data).tolist())
print("rhs", F.asnumpy(rhs_data).tolist())
for i, (x, y) in enumerate(zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())): for i, (x, y) in enumerate(zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())):
if not np.allclose(x, y, rtol, atol): if not np.allclose(x, y, rtol, atol):
print('@{} {} v.s. {}'.format(i, x, y)) print('@{} {} v.s. {}'.format(i, x, y))
......
...@@ -668,9 +668,9 @@ def check_weighted_negative_sampler(mode, exclude_positive, neg_size): ...@@ -668,9 +668,9 @@ def check_weighted_negative_sampler(mode, exclude_positive, neg_size):
def check_positive_edge_sampler(): def check_positive_edge_sampler():
g = generate_rand_graph(1000) g = generate_rand_graph(1000)
num_edges = g.number_of_edges() num_edges = g.number_of_edges()
edge_weight = F.copy_to(F.tensor(np.full((num_edges,), 1, dtype=np.float32)), F.cpu()) edge_weight = F.copy_to(F.tensor(np.full((num_edges,), 0.1, dtype=np.float32)), F.cpu())
edge_weight[num_edges-1] = num_edges ** 3 edge_weight[num_edges-1] = num_edges ** 2
EdgeSampler = getattr(dgl.contrib.sampling, 'EdgeSampler') EdgeSampler = getattr(dgl.contrib.sampling, 'EdgeSampler')
# Correctness check # Correctness check
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment