Unverified Commit 57daf9c9 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Graph] small fix on the all index (#219)

parent 57b07fce
...@@ -592,7 +592,9 @@ class DGLGraph(object): ...@@ -592,7 +592,9 @@ class DGLGraph(object):
-------- --------
in_degree in_degree
""" """
if not is_all(v): if is_all(v):
v = utils.toindex(slice(0, self.number_of_nodes()))
else:
v = utils.toindex(v) v = utils.toindex(v)
return self._graph.in_degrees(v).tousertensor() return self._graph.in_degrees(v).tousertensor()
...@@ -632,7 +634,9 @@ class DGLGraph(object): ...@@ -632,7 +634,9 @@ class DGLGraph(object):
-------- --------
out_degree out_degree
""" """
if not is_all(v): if is_all(v):
v = utils.toindex(slice(0, self.number_of_nodes()))
else:
v = utils.toindex(v) v = utils.toindex(v)
return self._graph.out_degrees(v).tousertensor() return self._graph.out_degrees(v).tousertensor()
......
...@@ -395,9 +395,6 @@ class GraphIndex(object): ...@@ -395,9 +395,6 @@ class GraphIndex(object):
int int
The in degree array. The in degree array.
""" """
if is_all(v):
v = np.arange(0, self.number_of_nodes(), dtype=np.int64)
v = utils.toindex(v)
v_array = v.todgltensor() v_array = v.todgltensor()
return utils.toindex(_CAPI_DGLGraphInDegrees(self._handle, v_array)) return utils.toindex(_CAPI_DGLGraphInDegrees(self._handle, v_array))
...@@ -429,9 +426,6 @@ class GraphIndex(object): ...@@ -429,9 +426,6 @@ class GraphIndex(object):
int int
The out degree array. The out degree array.
""" """
if is_all(v):
v = np.arange(0, self.number_of_nodes(), dtype=np.int64)
v = utils.toindex(v)
v_array = v.todgltensor() v_array = v.todgltensor()
return utils.toindex(_CAPI_DGLGraphOutDegrees(self._handle, v_array)) return utils.toindex(_CAPI_DGLGraphOutDegrees(self._handle, v_array))
......
...@@ -365,7 +365,7 @@ class ImmutableGraphIndex(object): ...@@ -365,7 +365,7 @@ class ImmutableGraphIndex(object):
The in degree array. The in degree array.
""" """
deg = self._get_in_degree() deg = self._get_in_degree()
if is_all(v): if v.is_slice(0, self.number_of_nodes()):
return utils.toindex(deg) return utils.toindex(deg)
else: else:
v_array = v.tousertensor() v_array = v.tousertensor()
...@@ -401,7 +401,7 @@ class ImmutableGraphIndex(object): ...@@ -401,7 +401,7 @@ class ImmutableGraphIndex(object):
The out degree array. The out degree array.
""" """
deg = self._get_out_degree() deg = self._get_out_degree()
if is_all(v): if v.is_slice(0, self.number_of_nodes()):
return utils.toindex(deg) return utils.toindex(deg)
else: else:
v_array = v.tousertensor() v_array = v.tousertensor()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment