Unverified Commit c49582c9 authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

[Test] Fix the random seed in test_kernel.py and fix the _print_error (#1194)



* Fix the random seed and fix the _print_error

* Fix check_positive_edge_sampler sample weight.
Co-authored-by: default avatarZihao Ye <zihaoye.cs@gmail.com>
Co-authored-by: default avatarVoVAllen <VoVAllen@users.noreply.github.com>
parent 64f6f3c1
......@@ -5,8 +5,6 @@ import numpy as np
import backend as F
from itertools import product
np.random.seed(31)
def udf_copy_src(edges):
return {'m': edges.src['u']}
......@@ -36,6 +34,7 @@ def generate_feature(g, broadcast='none', binary_op='none'):
"""Create graph with src, edge, dst feature. broadcast can be 'u',
'e', 'v', 'none'
"""
np.random.seed(31)
nv = g.number_of_nodes()
ne = g.number_of_edges()
if binary_op == 'dot':
......@@ -303,8 +302,21 @@ def test_all_binary_builtins():
def _print_error(a, b):
print("ERROR: Test {}_{}_{}_{} broadcast: {} partial: {}".
format(lhs, binary_op, rhs, reducer, broadcast, partial))
print("lhs", lhs)
print("rhs", rhs)
if lhs == 'u':
lhs_data = hu
elif lhs == 'v':
lhs_data = hv
elif lhs == 'e':
lhs_data = he
if rhs == 'u':
rhs_data = hu
elif rhs == 'v':
rhs_data = hv
elif rhs == 'e':
rhs_data = he
print("lhs", F.asnumpy(lhs_data).tolist())
print("rhs", F.asnumpy(rhs_data).tolist())
for i, (x, y) in enumerate(zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())):
if not np.allclose(x, y, rtol, atol):
print('@{} {} v.s. {}'.format(i, x, y))
......
......@@ -668,9 +668,9 @@ def check_weighted_negative_sampler(mode, exclude_positive, neg_size):
def check_positive_edge_sampler():
g = generate_rand_graph(1000)
num_edges = g.number_of_edges()
edge_weight = F.copy_to(F.tensor(np.full((num_edges,), 1, dtype=np.float32)), F.cpu())
edge_weight = F.copy_to(F.tensor(np.full((num_edges,), 0.1, dtype=np.float32)), F.cpu())
edge_weight[num_edges-1] = num_edges ** 3
edge_weight[num_edges-1] = num_edges ** 2
EdgeSampler = getattr(dgl.contrib.sampling, 'EdgeSampler')
# Correctness check
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment