Unverified Commit 7d8522a2 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bug] fix cumsum on an empty array with prepend_zero returning an empty array (#2179)



* fix cumsum

* udp
Co-authored-by: default avatarZihao <expye@outlook.com>
parent 8a227bfa
......@@ -141,6 +141,8 @@ def copy_to(input, ctx, **kwargs):
return input.as_in_context(ctx)
def sum(input, dim, keepdims=False):
if len(input) == 0:
return nd.array([0.], dtype=input.dtype, ctx=input.context)
return nd.sum(input, axis=dim, keepdims=keepdims)
def reduce_sum(input):
......
......@@ -5334,9 +5334,6 @@ class DGLHeteroGraph(object):
>>> g.format()
{'created': ['coo', 'csr', 'csc'], 'not created': []}
"""
if self.num_edges() == 0:
return 0
return self._graph.create_formats_()
def astype(self, idtype):
......
......@@ -14,7 +14,7 @@ template <DLDeviceType XPU, typename IdType>
IdArray CumSum(IdArray array, bool prepend_zero) {
const int64_t len = array.NumElements();
if (len == 0)
return array;
return !prepend_zero ? array : aten::Full(0, 1, array->dtype.bits, array->ctx);
if (prepend_zero) {
IdArray ret = aten::NewIdArray(len + 1, array->ctx, array->dtype.bits);
const IdType* in_d = array.Ptr<IdType>();
......
......@@ -17,7 +17,8 @@ template <DLDeviceType XPU, typename IdType>
IdArray CumSum(IdArray array, bool prepend_zero) {
const int64_t len = array.NumElements();
if (len == 0)
return array;
return !prepend_zero ? array : aten::Full(0, 1, array->dtype.bits, array->ctx);
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
auto device = runtime::DeviceAPI::Get(array->ctx);
const IdType* in_d = array.Ptr<IdType>();
......
......@@ -365,6 +365,40 @@ def test_query(idtype):
# test repr
print(g)
@parametrize_dtype
def test_empty_query(idtype):
g = dgl.graph(([1, 2, 3], [0, 4, 5]), idtype=idtype, device=F.ctx())
g.add_nodes(0)
g.add_edges([], [])
g.remove_edges([])
g.remove_nodes([])
assert F.shape(g.has_nodes([])) == (0,)
assert F.shape(g.has_edges_between([], [])) == (0,)
g.edge_ids([], [])
g.edge_ids([], [], return_uv=True)
g.find_edges([])
assert F.shape(g.in_edges([], form='eid')) == (0,)
u, v = g.in_edges([], form='uv')
assert F.shape(u) == (0,)
assert F.shape(v) == (0,)
u, v, e = g.in_edges([], form='all')
assert F.shape(u) == (0,)
assert F.shape(v) == (0,)
assert F.shape(e) == (0,)
assert F.shape(g.out_edges([], form='eid')) == (0,)
u, v = g.out_edges([], form='uv')
assert F.shape(u) == (0,)
assert F.shape(v) == (0,)
u, v, e = g.out_edges([], form='all')
assert F.shape(u) == (0,)
assert F.shape(v) == (0,)
assert F.shape(e) == (0,)
assert F.shape(g.in_degrees([])) == (0,)
assert F.shape(g.out_degrees([])) == (0,)
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU does not have COO impl.")
def _test_hypersparse():
N1 = 1 << 50 # should crash if allocated a CSR
......@@ -2460,6 +2494,7 @@ if __name__ == '__main__':
#test_remove_edges(F.int32)
#test_remove_nodes(F.int32)
#test_clone(F.int32)
test_frame(F.int32)
test_frame_device(F.int32)
#test_frame(F.int32)
#test_frame_device(F.int32)
#test_empty_query(F.int32)
pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment