Unverified Commit 2ce426d9 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bug] add_edges() crashes if the input tensor is empty (#2100)

* [Bug] add_edges() crashes if the input tensor is empty

* lint

* fix
parent 5bc029b9
...@@ -49,6 +49,8 @@ def tensor(data, dtype=None): ...@@ -49,6 +49,8 @@ def tensor(data, dtype=None):
if dtype is None: if dtype is None:
if isinstance(data, np.ndarray): if isinstance(data, np.ndarray):
dtype = np.int32 if data.dtype == np.bool else data.dtype dtype = np.int32 if data.dtype == np.bool else data.dtype
elif len(data) == 0:
dtype = np.int64
else: else:
dtype = np.int64 if isinstance(data[0], numbers.Integral) else np.float32 dtype = np.int64 if isinstance(data[0], numbers.Integral) else np.float32
return nd.array(data, dtype=dtype) return nd.array(data, dtype=dtype)
......
...@@ -35,7 +35,8 @@ def prepare_tensor(g, data, name): ...@@ -35,7 +35,8 @@ def prepare_tensor(g, data, name):
ret = data ret = data
else: else:
data = F.tensor(data) data = F.tensor(data)
if F.dtype(data) not in (F.int32, F.int64): if (not (F.ndim(data) > 0 and F.shape(data)[0] == 0) and # empty tensor
F.dtype(data) not in (F.int32, F.int64)):
raise DGLError('Expect argument "{}" to have data type int32 or int64,' raise DGLError('Expect argument "{}" to have data type int32 or int64,'
' but got {}.'.format(name, F.dtype(data))) ' but got {}.'.format(name, F.dtype(data)))
ret = F.copy_to(F.astype(data, g.idtype), g.device) ret = F.copy_to(F.astype(data, g.idtype), g.device)
......
...@@ -1004,6 +1004,15 @@ def test_add_edges(idtype): ...@@ -1004,6 +1004,15 @@ def test_add_edges(idtype):
u, v = g.edges(form='uv', order='eid') u, v = g.edges(form='uv', order='eid')
assert F.array_equal(u, F.tensor([0, 1, 0, 0, 0], dtype=idtype)) assert F.array_equal(u, F.tensor([0, 1, 0, 0, 0], dtype=idtype))
assert F.array_equal(v, F.tensor([1, 2, 1, 1, 1], dtype=idtype)) assert F.array_equal(v, F.tensor([1, 2, 1, 1, 1], dtype=idtype))
g = dgl.add_edges(g, [], [])
g = dgl.add_edges(g, 0, [])
g = dgl.add_edges(g, [], 0)
assert g.device == F.ctx()
assert g.number_of_nodes() == 3
assert g.number_of_edges() == 5
u, v = g.edges(form='uv', order='eid')
assert F.array_equal(u, F.tensor([0, 1, 0, 0, 0], dtype=idtype))
assert F.array_equal(v, F.tensor([1, 2, 1, 1, 1], dtype=idtype))
# node id larger than current max node id # node id larger than current max node id
g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx()) g = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=F.ctx())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment