Unverified Commit 3e43d7b8 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Test] Lower allclose precision (#129)

parent bd0e4fa0
......@@ -2,6 +2,7 @@ import torch as th
from torch.autograd import Variable
import numpy as np
from dgl.graph import DGLGraph
import utils as U
D = 5
reduce_msg_shapes = set()
......@@ -49,7 +50,7 @@ def test_batch_setter_getter():
g = generate_graph()
# set all nodes
g.ndata['h'] = th.zeros((10, D))
assert th.allclose(g.ndata['h'], th.zeros((10, D)))
assert U.allclose(g.ndata['h'], th.zeros((10, D)))
# pop nodes
old_len = len(g.ndata)
assert _pfc(g.pop_n_repr('h')) == [0.] * 10
......@@ -175,12 +176,12 @@ def test_apply_edges():
g.register_apply_edge_func(_upd)
old = g.edata['w']
g.apply_edges()
assert th.allclose(old * 2, g.edata['w'])
assert U.allclose(old * 2, g.edata['w'])
u = th.tensor([0, 0, 0, 4, 5, 6])
v = th.tensor([1, 2, 3, 9, 9, 9])
g.apply_edges(lambda edges : {'w' : edges.data['w'] * 0.}, (u, v))
eid = g.edge_ids(u, v)
assert th.allclose(g.edata['w'][eid], th.zeros((6, D)))
assert U.allclose(g.edata['w'][eid], th.zeros((6, D)))
def test_update_routines():
g = generate_graph()
......@@ -232,8 +233,8 @@ def test_reduce_0deg():
g.update_all(_message, _reduce)
new_repr = g.ndata['h']
assert th.allclose(new_repr[1:], old_repr[1:])
assert th.allclose(new_repr[0], old_repr.sum(0), rtol=1e-3, atol=1e-3)
assert U.allclose(new_repr[1:], old_repr[1:])
assert U.allclose(new_repr[0], old_repr.sum(0))
def test_pull_0deg():
g = DGLGraph()
......@@ -248,19 +249,19 @@ def test_pull_0deg():
g.pull(0, _message, _reduce)
new_repr = g.ndata['h']
assert th.allclose(new_repr[0], old_repr[0])
assert th.allclose(new_repr[1], old_repr[1])
assert U.allclose(new_repr[0], old_repr[0])
assert U.allclose(new_repr[1], old_repr[1])
g.pull(1, _message, _reduce)
new_repr = g.ndata['h']
assert th.allclose(new_repr[1], old_repr[0])
assert U.allclose(new_repr[1], old_repr[0])
old_repr = th.randn(2, 5)
g.ndata['h'] = old_repr
g.pull([0, 1], _message, _reduce)
new_repr = g.ndata['h']
assert th.allclose(new_repr[0], old_repr[0])
assert th.allclose(new_repr[1], old_repr[0])
assert U.allclose(new_repr[0], old_repr[0])
assert U.allclose(new_repr[1], old_repr[0])
def _disabled_test_send_twice():
# TODO(minjie): please re-enable this unittest after the send code problem is fixed.
......@@ -281,14 +282,14 @@ def _disabled_test_send_twice():
g.send((0, 1), _message_b)
g.recv(1, _reduce)
new_repr = g.ndata['a']
assert th.allclose(new_repr[1], old_repr[0] * 3)
assert U.allclose(new_repr[1], old_repr[0] * 3)
g.ndata['a'] = old_repr
g.send((0, 1), _message_a)
g.send((2, 1), _message_b)
g.recv(1, _reduce)
new_repr = g.ndata['a']
assert th.allclose(new_repr[1], th.stack([old_repr[0], old_repr[2] * 3], 0).max(0)[0])
assert U.allclose(new_repr[1], th.stack([old_repr[0], old_repr[2] * 3], 0).max(0)[0])
def test_send_multigraph():
g = DGLGraph(multigraph=True)
......@@ -315,14 +316,14 @@ def test_send_multigraph():
g.send([0, 2], message_func=_message_a)
g.recv(1, _reduce)
new_repr = g.ndata['a']
assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[2]))
assert U.allclose(new_repr[1], answer(old_repr[0], old_repr[2]))
g.ndata['a'] = th.zeros(3, 5)
g.edata['a'] = old_repr
g.send([0, 2, 3], message_func=_message_a)
g.recv(1, _reduce)
new_repr = g.ndata['a']
assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[2], old_repr[3]))
assert U.allclose(new_repr[1], answer(old_repr[0], old_repr[2], old_repr[3]))
# send on multigraph
g.ndata['a'] = th.zeros(3, 5)
......@@ -330,7 +331,7 @@ def test_send_multigraph():
g.send(([0, 2], [1, 1]), _message_a)
g.recv(1, _reduce)
new_repr = g.ndata['a']
assert th.allclose(new_repr[1], old_repr.max(0)[0])
assert U.allclose(new_repr[1], old_repr.max(0)[0])
# consecutive send and send_on
g.ndata['a'] = th.zeros(3, 5)
......@@ -339,7 +340,7 @@ def test_send_multigraph():
g.send([0, 1], message_func=_message_b)
g.recv(1, _reduce)
new_repr = g.ndata['a']
assert th.allclose(new_repr[1], answer(old_repr[0] * 3, old_repr[1] * 3, old_repr[3]))
assert U.allclose(new_repr[1], answer(old_repr[0] * 3, old_repr[1] * 3, old_repr[3]))
# consecutive send_on
g.ndata['a'] = th.zeros(3, 5)
......@@ -348,15 +349,15 @@ def test_send_multigraph():
g.send(1, message_func=_message_b)
g.recv(1, _reduce)
new_repr = g.ndata['a']
assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[1] * 3))
assert U.allclose(new_repr[1], answer(old_repr[0], old_repr[1] * 3))
# send_and_recv_on
g.ndata['a'] = th.zeros(3, 5)
g.edata['a'] = old_repr
g.send_and_recv([0, 2, 3], message_func=_message_a, reduce_func=_reduce)
new_repr = g.ndata['a']
assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[2], old_repr[3]))
assert th.allclose(new_repr[[0, 2]], th.zeros(2, 5))
assert U.allclose(new_repr[1], answer(old_repr[0], old_repr[2], old_repr[3]))
assert U.allclose(new_repr[[0, 2]], th.zeros(2, 5))
def test_dynamic_addition():
N = 3
......
......@@ -2,6 +2,7 @@ import networkx as nx
import dgl
import torch as th
import numpy as np
import utils as U
def tree1():
"""Generate a tree
......@@ -57,10 +58,10 @@ def test_batch_unbatch():
assert bg.batch_num_edges == [4, 4]
tt1, tt2 = dgl.unbatch(bg)
assert th.allclose(t1.ndata['h'], tt1.ndata['h'])
assert th.allclose(t1.edata['h'], tt1.edata['h'])
assert th.allclose(t2.ndata['h'], tt2.ndata['h'])
assert th.allclose(t2.edata['h'], tt2.edata['h'])
assert U.allclose(t1.ndata['h'], tt1.ndata['h'])
assert U.allclose(t1.edata['h'], tt1.edata['h'])
assert U.allclose(t2.ndata['h'], tt2.ndata['h'])
assert U.allclose(t2.edata['h'], tt2.edata['h'])
def test_batch_unbatch1():
t1 = tree1()
......@@ -74,12 +75,12 @@ def test_batch_unbatch1():
assert b2.batch_num_edges == [4, 4, 4]
s1, s2, s3 = dgl.unbatch(b2)
assert th.allclose(t2.ndata['h'], s1.ndata['h'])
assert th.allclose(t2.edata['h'], s1.edata['h'])
assert th.allclose(t1.ndata['h'], s2.ndata['h'])
assert th.allclose(t1.edata['h'], s2.edata['h'])
assert th.allclose(t2.ndata['h'], s3.ndata['h'])
assert th.allclose(t2.edata['h'], s3.edata['h'])
assert U.allclose(t2.ndata['h'], s1.ndata['h'])
assert U.allclose(t2.edata['h'], s1.edata['h'])
assert U.allclose(t1.ndata['h'], s2.ndata['h'])
assert U.allclose(t1.edata['h'], s2.edata['h'])
assert U.allclose(t2.ndata['h'], s3.ndata['h'])
assert U.allclose(t2.edata['h'], s3.edata['h'])
def test_batch_sendrecv():
t1 = tree1()
......
import torch as th
import numpy as np
from dgl.graph import DGLGraph
import utils as U
def test_filter():
g = DGLGraph()
......
......@@ -3,6 +3,7 @@ from torch.autograd import Variable
import numpy as np
from dgl.frame import Frame, FrameRef
from dgl.utils import Index, toindex
import utils as U
N = 10
D = 5
......@@ -43,9 +44,9 @@ def test_column1():
f = Frame(data)
assert f.num_rows == N
assert len(f) == 3
assert th.allclose(f['a1'].data, data['a1'].data)
assert U.allclose(f['a1'].data, data['a1'].data)
f['a1'] = data['a2']
assert th.allclose(f['a2'].data, data['a2'].data)
assert U.allclose(f['a2'].data, data['a2'].data)
# add a different length column should fail
def failed_add_col():
f['a4'] = th.zeros([N+1, D])
......@@ -68,10 +69,10 @@ def test_column2():
f = FrameRef(data, [3, 4, 5, 6, 7])
assert f.num_rows == 5
assert len(f) == 3
assert th.allclose(f['a1'], data['a1'].data[3:8])
assert U.allclose(f['a1'], data['a1'].data[3:8])
# set column should reflect on the referenced data
f['a1'] = th.zeros([5, D])
assert th.allclose(data['a1'].data[3:8], th.zeros([5, D]))
assert U.allclose(data['a1'].data[3:8], th.zeros([5, D]))
# add new partial column should fail with error initializer
f.set_initializer(lambda shape, dtype : assert_(False))
def failed_add_col():
......@@ -90,7 +91,7 @@ def test_append1():
c1 = f1['a1']
assert c1.data.shape == (2 * N, D)
truth = th.cat([data['a1'], data['a1']])
assert th.allclose(truth, c1.data)
assert U.allclose(truth, c1.data)
# append dict of different length columns should fail
f3 = {'a1' : th.zeros((3, D)), 'a2' : th.zeros((3, D)), 'a3' : th.zeros((2, D))}
def failed_append():
......@@ -129,13 +130,13 @@ def test_row1():
rows = f[rowid]
for k, v in rows.items():
assert v.shape == (len(rowid), D)
assert th.allclose(v, data[k][rowid])
assert U.allclose(v, data[k][rowid])
# test duplicate keys
rowid = Index(th.tensor([8, 2, 2, 1]))
rows = f[rowid]
for k, v in rows.items():
assert v.shape == (len(rowid), D)
assert th.allclose(v, data[k][rowid])
assert U.allclose(v, data[k][rowid])
# setter
rowid = Index(th.tensor([0, 2, 4]))
......@@ -145,7 +146,7 @@ def test_row1():
}
f[rowid] = vals
for k, v in f[rowid].items():
assert th.allclose(v, th.zeros((len(rowid), D)))
assert U.allclose(v, th.zeros((len(rowid), D)))
# setting rows with new column should raise error with error initializer
f.set_initializer(lambda shape, dtype : assert_(False))
......@@ -165,13 +166,13 @@ def test_row2():
rowid = Index(th.tensor([0, 2]))
rows = f[rowid]
rows['a1'].backward(th.ones((len(rowid), D)))
assert th.allclose(c1.grad[:,0], th.tensor([1., 0., 1., 0., 0., 0., 0., 0., 0., 0.]))
assert U.allclose(c1.grad[:,0], th.tensor([1., 0., 1., 0., 0., 0., 0., 0., 0., 0.]))
c1.grad.data.zero_()
# test duplicate keys
rowid = Index(th.tensor([8, 2, 2, 1]))
rows = f[rowid]
rows['a1'].backward(th.ones((len(rowid), D)))
assert th.allclose(c1.grad[:,0], th.tensor([0., 1., 2., 0., 0., 0., 0., 0., 1., 0.]))
assert U.allclose(c1.grad[:,0], th.tensor([0., 1., 2., 0., 0., 0., 0., 0., 1., 0.]))
c1.grad.data.zero_()
# setter
......@@ -184,8 +185,8 @@ def test_row2():
f[rowid] = vals
c11 = f['a1']
c11.backward(th.ones((N, D)))
assert th.allclose(c1.grad[:,0], th.tensor([0., 1., 0., 1., 0., 1., 1., 1., 1., 1.]))
assert th.allclose(vals['a1'].grad, th.ones((len(rowid), D)))
assert U.allclose(c1.grad[:,0], th.tensor([0., 1., 0., 1., 0., 1., 1., 1., 1., 1.]))
assert U.allclose(vals['a1'].grad, th.ones((len(rowid), D)))
assert vals['a2'].grad is None
def test_row3():
......@@ -207,7 +208,7 @@ def test_row3():
newidx.pop(2)
newidx = toindex(newidx)
for k, v in f.items():
assert th.allclose(v, data[k][newidx])
assert U.allclose(v, data[k][newidx])
def test_sharing():
data = Frame(create_test_data())
......@@ -215,9 +216,9 @@ def test_sharing():
f2 = FrameRef(data, index=[2, 3, 4, 5, 6])
# test read
for k, v in f1.items():
assert th.allclose(data[k].data[0:4], v)
assert U.allclose(data[k].data[0:4], v)
for k, v in f2.items():
assert th.allclose(data[k].data[2:7], v)
assert U.allclose(data[k].data[2:7], v)
f2_a1 = f2['a1'].data
# test write
# update own ref should not been seen by the other.
......@@ -226,7 +227,7 @@ def test_sharing():
'a2' : th.zeros([2, D]),
'a3' : th.zeros([2, D]),
}
assert th.allclose(f2['a1'], f2_a1)
assert U.allclose(f2['a1'], f2_a1)
# update shared space should been seen by the other.
f1[Index(th.tensor([2, 3]))] = {
'a1' : th.ones([2, D]),
......@@ -234,7 +235,7 @@ def test_sharing():
'a3' : th.ones([2, D]),
}
f2_a1[0:2] = th.ones([2, D])
assert th.allclose(f2['a1'], f2_a1)
assert U.allclose(f2['a1'], f2_a1)
def test_slicing():
data = Frame(create_test_data(grad=True))
......@@ -242,7 +243,7 @@ def test_slicing():
f2 = FrameRef(data, index=slice(3, 8))
# test read
for k, v in f1.items():
assert th.allclose(data[k].data[1:5], v)
assert U.allclose(data[k].data[1:5], v)
f2_a1 = f2['a1'].data
# test write
f1[Index(th.tensor([0, 1]))] = {
......@@ -250,7 +251,7 @@ def test_slicing():
'a2': th.zeros([2, D]),
'a3': th.zeros([2, D]),
}
assert th.allclose(f2['a1'], f2_a1)
assert U.allclose(f2['a1'], f2_a1)
f1[Index(th.tensor([2, 3]))] = {
'a1': th.ones([2, D]),
......@@ -258,7 +259,7 @@ def test_slicing():
'a3': th.ones([2, D]),
}
f2_a1[0:2] = 1
assert th.allclose(f2['a1'], f2_a1)
assert U.allclose(f2['a1'], f2_a1)
f1[2:4] = {
'a1': th.zeros([2, D]),
......@@ -266,7 +267,7 @@ def test_slicing():
'a3': th.zeros([2, D]),
}
f2_a1[0:2] = 0
assert th.allclose(f2['a1'], f2_a1)
assert U.allclose(f2['a1'], f2_a1)
if __name__ == '__main__':
test_create()
......
import torch as th
import dgl
import dgl.function as fn
import utils as U
def generate_graph():
g = dgl.DGLGraph()
......@@ -27,7 +28,7 @@ def test_copy_src():
g.register_message_func(fn.copy_src(src='h', out='m'))
g.register_reduce_func(reducer_both)
g.update_all()
assert th.allclose(g.ndata['h'],
assert U.allclose(g.ndata['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
def test_copy_edge():
......@@ -36,7 +37,7 @@ def test_copy_edge():
g.register_message_func(fn.copy_edge(edge='h', out='m'))
g.register_reduce_func(reducer_both)
g.update_all()
assert th.allclose(g.ndata['h'],
assert U.allclose(g.ndata['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
def test_src_mul_edge():
......@@ -45,7 +46,7 @@ def test_src_mul_edge():
g.register_message_func(fn.src_mul_edge(src='h', edge='h', out='m'))
g.register_reduce_func(reducer_both)
g.update_all()
assert th.allclose(g.ndata['h'],
assert U.allclose(g.ndata['h'],
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
if __name__ == '__main__':
......
......@@ -4,6 +4,7 @@ from dgl.utils import toindex
import numpy as np
import torch as th
from torch.utils import dlpack
import utils as U
def test_dlpack():
# test dlpack conversion.
......
......@@ -2,6 +2,7 @@ import torch as th
import networkx as nx
import numpy as np
import dgl
import utils as U
D = 5
......@@ -19,13 +20,13 @@ def test_line_graph():
v = [1, 2, 0, 0]
eid = G.edge_ids(u, v)
L.nodes[eid].data['h'] = th.zeros((4, D))
assert th.allclose(G.edges[u, v].data['h'], th.zeros((4, D)))
assert U.allclose(G.edges[u, v].data['h'], th.zeros((4, D)))
# adding a new node feature on line graph should also reflect to a new
# edge feature on original graph
data = th.randn(n_edges, D)
L.ndata['w'] = data
assert th.allclose(G.edata['w'], data)
assert U.allclose(G.edata['w'], data)
def test_no_backtracking():
N = 5
......
......@@ -2,6 +2,7 @@ import torch as th
import numpy as np
import dgl
import dgl.function as fn
import utils as U
D = 5
......@@ -43,7 +44,7 @@ def test_update_all():
g.set_n_repr({fld : v1})
g.update_all(message_func, reduce_func, apply_func)
v3 = g.ndata[fld]
assert th.allclose(v2, v3)
assert U.allclose(v2, v3)
# update all with edge weights
v1 = g.ndata[fld]
g.update_all(fn.src_mul_edge(src=fld, edge='e1', out='m'),
......@@ -56,8 +57,8 @@ def test_update_all():
g.set_n_repr({fld : v1})
g.update_all(message_func_edge, reduce_func, apply_func)
v4 = g.ndata[fld]
assert th.allclose(v2, v3)
assert th.allclose(v3, v4)
assert U.allclose(v2, v3)
assert U.allclose(v3, v4)
# test 1d node features
_test('f1')
# test 2d node features
......@@ -90,7 +91,7 @@ def test_send_and_recv():
g.set_n_repr({fld : v1})
g.send_and_recv((u, v), message_func, reduce_func, apply_func)
v3 = g.ndata[fld]
assert th.allclose(v2, v3)
assert U.allclose(v2, v3)
# send and recv with edge weights
v1 = g.ndata[fld]
g.send_and_recv((u, v), fn.src_mul_edge(src=fld, edge='e1', out='m'),
......@@ -103,8 +104,8 @@ def test_send_and_recv():
g.set_n_repr({fld : v1})
g.send_and_recv((u, v), message_func_edge, reduce_func, apply_func)
v4 = g.ndata[fld]
assert th.allclose(v2, v3)
assert th.allclose(v3, v4)
assert U.allclose(v2, v3)
assert U.allclose(v3, v4)
# test 1d node features
_test('f1')
# test 2d node features
......@@ -129,19 +130,19 @@ def test_update_all_multi_fn():
None)
v1 = g.ndata['v1']
v2 = g.ndata['v2']
assert th.allclose(v1, v2)
assert U.allclose(v1, v2)
# run builtin with single message and reduce
g.update_all(fn.copy_src(src=fld, out='m'), fn.sum(msg='m', out='v1'), None)
v1 = g.ndata['v1']
assert th.allclose(v1, v2)
assert U.allclose(v1, v2)
# 1 message, 2 reduces
g.update_all(fn.copy_src(src=fld, out='m'), [fn.sum(msg='m', out='v2'), fn.sum(msg='m', out='v3')], None)
v2 = g.ndata['v2']
v3 = g.ndata['v3']
assert th.allclose(v1, v2)
assert th.allclose(v1, v3)
assert U.allclose(v1, v2)
assert U.allclose(v1, v3)
# update all with edge weights, 2 message, 3 reduces
g.update_all([fn.src_mul_edge(src=fld, edge='e1', out='m1'), fn.src_mul_edge(src=fld, edge='e2', out='m2')],
......@@ -150,13 +151,13 @@ def test_update_all_multi_fn():
v1 = g.ndata['v1']
v2 = g.ndata['v2']
v3 = g.ndata['v3']
assert th.allclose(v1, v2)
assert th.allclose(v1, v3)
assert U.allclose(v1, v2)
assert U.allclose(v1, v3)
# run UDF with single message and reduce
g.update_all(message_func_edge, reduce_func, None)
v2 = g.ndata['v2']
assert th.allclose(v1, v2)
assert U.allclose(v1, v2)
def test_send_and_recv_multi_fn():
u = th.tensor([0, 0, 0, 3, 4, 9])
......@@ -183,13 +184,13 @@ def test_send_and_recv_multi_fn():
None)
v1 = g.ndata['v1']
v2 = g.ndata['v2']
assert th.allclose(v1, v2)
assert U.allclose(v1, v2)
# run builtin with single message and reduce
g.send_and_recv((u, v), fn.copy_src(src=fld, out='m'), fn.sum(msg='m', out='v1'),
None)
v1 = g.ndata['v1']
assert th.allclose(v1, v2)
assert U.allclose(v1, v2)
# 1 message, 2 reduces
g.send_and_recv((u, v),
......@@ -198,8 +199,8 @@ def test_send_and_recv_multi_fn():
None)
v2 = g.ndata['v2']
v3 = g.ndata['v3']
assert th.allclose(v1, v2)
assert th.allclose(v1, v3)
assert U.allclose(v1, v2)
assert U.allclose(v1, v3)
# send and recv with edge weights, 2 message, 3 reduces
g.send_and_recv((u, v),
......@@ -209,14 +210,14 @@ def test_send_and_recv_multi_fn():
v1 = g.ndata['v1']
v2 = g.ndata['v2']
v3 = g.ndata['v3']
assert th.allclose(v1, v2)
assert th.allclose(v1, v3)
assert U.allclose(v1, v2)
assert U.allclose(v1, v3)
# run UDF with single message and reduce
g.send_and_recv((u, v), message_func_edge,
reduce_func, None)
v2 = g.ndata['v2']
assert th.allclose(v1, v2)
assert U.allclose(v1, v2)
if __name__ == '__main__':
test_update_all()
......
......@@ -2,6 +2,7 @@ import torch as th
from torch.autograd import Variable
import numpy as np
from dgl.graph import DGLGraph
import utils as U
D = 5
......@@ -37,7 +38,7 @@ def test_basics():
assert len(sg.ndata) == 1
assert len(sg.edata) == 1
sh = sg.ndata['h']
assert th.allclose(h[nid], sh)
assert U.allclose(h[nid], sh)
'''
s, d, eid
0, 1, 0
......@@ -58,11 +59,11 @@ def test_basics():
8, 9, 15 3
9, 0, 16 1
'''
assert th.allclose(l[eid], sg.edata['l'])
assert U.allclose(l[eid], sg.edata['l'])
# update the node/edge features on the subgraph should NOT
# reflect to the parent graph.
sg.ndata['h'] = th.zeros((6, D))
assert th.allclose(h, g.ndata['h'])
assert U.allclose(h, g.ndata['h'])
def test_merge():
# FIXME: current impl cannot handle this case!!!
......@@ -87,8 +88,8 @@ def test_merge():
h = g.ndata['h'][:,0]
l = g.edata['l'][:,0]
assert th.allclose(h, th.tensor([3., 0., 3., 3., 2., 0., 1., 1., 0., 1.]))
assert th.allclose(l,
assert U.allclose(h, th.tensor([3., 0., 3., 3., 2., 0., 1., 1., 0., 1.]))
assert U.allclose(l,
th.tensor([0., 0., 1., 1., 1., 1., 0., 0., 0., 3., 1., 4., 1., 4., 0., 3., 1.]))
"""
......
......@@ -7,6 +7,7 @@ import networkx as nx
import numpy as np
import scipy.sparse as sp
import torch as th
import utils as U
np.random.seed(42)
......
import torch as th
def allclose(a, b):
return th.allclose(a, b, rtol=1e-4, atol=1e-4)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment