"sgl-kernel/git@developer.sourcefind.cn:change/sglang.git" did not exist on "c9bcffd2a53423e6a183e312a58675fb48435d2a"
Unverified Commit a0390dde authored by nv-dlasalle's avatar nv-dlasalle Committed by GitHub
Browse files

[Feature] Add dgl.utils.is_sorted_srcdst() (#2685)



* Add dgl.utils.is_sorted_srcdst

* Fix linting issues

* delete blank line

* Specify datatype to index tensor in test

* Force integer conversion
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
parent 95f8ec83
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
from __future__ import absolute_import, division from __future__ import absolute_import, division
from ..base import DGLError from ..base import DGLError
from .._ffi.function import _init_api
from .. import backend as F from .. import backend as F
def prepare_tensor(g, data, name): def prepare_tensor(g, data, name):
...@@ -166,3 +167,42 @@ def check_valid_idtype(idtype): ...@@ -166,3 +167,42 @@ def check_valid_idtype(idtype):
if idtype not in [None, F.int32, F.int64]: if idtype not in [None, F.int32, F.int64]:
raise DGLError('Expect idtype to be a framework object of int32/int64, ' raise DGLError('Expect idtype to be a framework object of int32/int64, '
'got {}'.format(idtype)) 'got {}'.format(idtype))
def is_sorted_srcdst(src, dst, num_src=None, num_dst=None):
"""Checks whether an edge list is in ascending src-major order (e.g., first
sorted by ``src`` and then by ``dst``).
Parameters
----------
src : IdArray
The tensor of source nodes for each edge.
dst : IdArray
The tensor of destination nodes for each edge.
num_src : int, optional
The number of source nodes.
num_dst : int, optional
The number of destination nodes.
Returns
-------
bool, bool
Whether ``src`` is in ascending order, and whether ``dst`` is
in ascending order with respect to ``src``.
"""
# for some versions of MXNET and TensorFlow, num_src and num_dst get
# incorrectly marked as floats, so force them as integers here
if num_src is None:
num_src = int(F.as_scalar(F.max(src, dim=0)+1))
if num_dst is None:
num_dst = int(F.as_scalar(F.max(dst, dim=0)+1))
src = F.zerocopy_to_dgl_ndarray(src)
dst = F.zerocopy_to_dgl_ndarray(dst)
sorted_status = _CAPI_DGLCOOIsSorted(src, dst, num_src, num_dst)
row_sorted = sorted_status > 0
col_sorted = sorted_status > 1
return row_sorted, col_sorted
_init_api("dgl.utils.checks")
...@@ -6,10 +6,17 @@ ...@@ -6,10 +6,17 @@
#include <dmlc/omp.h> #include <dmlc/omp.h>
#include <dgl/aten/coo.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include <utility>
#include "../c_api_common.h" #include "../c_api_common.h"
#include "../array/array_op.h"
using namespace dgl::runtime; using namespace dgl::runtime;
using namespace dgl::aten::impl;
namespace dgl { namespace dgl {
...@@ -19,4 +26,24 @@ DGL_REGISTER_GLOBAL("utils.internal._CAPI_DGLSetOMPThreads") ...@@ -19,4 +26,24 @@ DGL_REGISTER_GLOBAL("utils.internal._CAPI_DGLSetOMPThreads")
omp_set_num_threads(num_threads); omp_set_num_threads(num_threads);
}); });
DGL_REGISTER_GLOBAL("utils.checks._CAPI_DGLCOOIsSorted")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
IdArray src = args[0];
IdArray dst = args[1];
int64_t num_src = args[2];
int64_t num_dst = args[3];
bool row_sorted, col_sorted;
std::tie(row_sorted, col_sorted) = COOIsSorted(
aten::COOMatrix(num_src, num_dst, src, dst));
// make sure col_sorted is only true when row_sorted is true
assert(!(!row_sorted && col_sorted));
// 0 for unosrted, 1 for row sorted, 2 for row and col sorted
int64_t sorted_status = row_sorted + col_sorted;
*rv = sorted_status;
});
} // namespace dgl } // namespace dgl
...@@ -344,6 +344,31 @@ def test_empty_data_initialized(): ...@@ -344,6 +344,31 @@ def test_empty_data_initialized():
assert "ha" in g.ndata assert "ha" in g.ndata
assert len(g.ndata["ha"]) == 1 assert len(g.ndata["ha"]) == 1
def test_is_sorted():
u_src, u_dst = edge_pair_input(False)
s_src, s_dst = edge_pair_input(True)
u_src = F.tensor(u_src, dtype=F.int32)
u_dst = F.tensor(u_dst, dtype=F.int32)
s_src = F.tensor(s_src, dtype=F.int32)
s_dst = F.tensor(s_dst, dtype=F.int32)
src_sorted, dst_sorted = dgl.utils.is_sorted_srcdst(u_src, u_dst)
assert src_sorted == False
assert dst_sorted == False
src_sorted, dst_sorted = dgl.utils.is_sorted_srcdst(s_src, s_dst)
assert src_sorted == True
assert dst_sorted == True
src_sorted, dst_sorted = dgl.utils.is_sorted_srcdst(u_src, u_dst)
assert src_sorted == False
assert dst_sorted == False
src_sorted, dst_sorted = dgl.utils.is_sorted_srcdst(s_src, u_dst)
assert src_sorted == True
assert dst_sorted == False
if __name__ == '__main__': if __name__ == '__main__':
test_query() test_query()
test_mutation() test_mutation()
...@@ -351,3 +376,4 @@ if __name__ == '__main__': ...@@ -351,3 +376,4 @@ if __name__ == '__main__':
test_incmat() test_incmat()
test_find_edges() test_find_edges()
test_hypersparse_query() test_hypersparse_query()
test_is_sorted()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment