"tests/vscode:/vscode.git/clone" did not exist on "b1eeb934494ef1eee20cf2d35b718790cb9cb550"
Unverified Commit cf829dc3 authored by Andrei Ivanov's avatar Andrei Ivanov Committed by GitHub
Browse files

Fixing some unittest warnings. (#6122)


Co-authored-by: default avatarHongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent 77c84834
......@@ -509,6 +509,7 @@ def backward(x, head_gradient=None):
def grad(x):
x.retain_grad()
return x.grad
......
......@@ -1231,6 +1231,16 @@ def bipartite_from_scipy(
return g.to(device)
def _batcher(lst):
if F.is_tensor(lst[0]):
return F.cat([F.unsqueeze(x, 0) for x in lst], dim=0)
if isinstance(lst[0], np.ndarray):
return F.tensor(np.array(lst))
return F.tensor(lst)
def from_networkx(
nx_graph,
node_attrs=None,
......@@ -1367,12 +1377,6 @@ def from_networkx(
# handle features
# copy attributes
def _batcher(lst):
if F.is_tensor(lst[0]):
return F.cat([F.unsqueeze(x, 0) for x in lst], dim=0)
else:
return F.tensor(lst)
if node_attrs is not None:
# mapping from feature name to a list of tensors to be concatenated
attr_dict = defaultdict(list)
......@@ -1592,12 +1596,6 @@ def bipartite_from_networkx(
# handle features
# copy attributes
def _batcher(lst):
if F.is_tensor(lst[0]):
return F.cat([F.unsqueeze(x, 0) for x in lst], dim=0)
else:
return F.tensor(lst)
if u_attrs is not None:
# mapping from feature name to a list of tensors to be concatenated
src_attr_dict = defaultdict(list)
......
......@@ -89,7 +89,7 @@ def test_prop_nodes_topo(idtype):
assert check_fail(dgl.prop_nodes_topo, g) # has loop
# tree
tree = dgl.DGLGraph()
tree = dgl.graph([])
tree.add_nodes(5)
tree.add_edges(1, 0)
tree.add_edges(2, 0)
......
......@@ -13,7 +13,7 @@ D = 5
def generate_graph(grad=False, add_data=True):
g = dgl.DGLGraph().to(F.ctx())
g = dgl.graph([]).to(F.ctx())
g.add_nodes(10)
# create a graph where 0 is the source and 9 is the sink
for i in range(1, 9):
......@@ -111,7 +111,7 @@ def test_subgraph_relabel_nodes(relabel_nodes):
def _test_map_to_subgraph():
g = dgl.DGLGraph()
g = dgl.graph([])
g.add_nodes(10)
g.add_edges(F.arange(0, 9), F.arange(1, 10))
h = g.subgraph([0, 1, 2, 5, 8])
......
......@@ -95,7 +95,7 @@ DFS_LABEL_NAMES = ["forward", "reverse", "nontree"]
@parametrize_idtype
def test_dfs_labeled_edges(idtype, example=False):
dgl_g = dgl.DGLGraph().astype(idtype)
dgl_g = dgl.graph([]).astype(idtype)
dgl_g.add_nodes(6)
dgl_g.add_edges([0, 1, 0, 3, 3], [1, 2, 2, 4, 5])
dgl_edges, dgl_labels = dgl.dfs_labeled_edges_generator(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment