Unverified Commit 6a6597a0 authored by Rhett Ying, committed by GitHub
Browse files

[Feature] extend sort_csr/csc_by_tag to edge (#4164)



* [Feature] extend sort_csr/csc_by_tag to edge

* fix test failure in tensorflow

* refine sorting by edges

* fix docstring

* remove unnecessary mem
Co-authored-by: Xin Yao <xiny@nvidia.com>
parent b76d0ed1
...@@ -2802,7 +2802,7 @@ def as_immutable_graph(hg): ...@@ -2802,7 +2802,7 @@ def as_immutable_graph(hg):
'\tdgl.as_immutable_graph will do nothing and can be removed safely in all cases.') '\tdgl.as_immutable_graph will do nothing and can be removed safely in all cases.')
return hg return hg
def sort_csr_by_tag(g, tag, tag_offset_name='_TAG_OFFSET'): def sort_csr_by_tag(g, tag, tag_offset_name='_TAG_OFFSET', tag_type='node'):
r"""Return a new graph whose CSR matrix is sorted by the given tag. r"""Return a new graph whose CSR matrix is sorted by the given tag.
Sort the internal CSR matrix of the graph so that the adjacency list of each node Sort the internal CSR matrix of the graph so that the adjacency list of each node
...@@ -2821,6 +2821,9 @@ def sort_csr_by_tag(g, tag, tag_offset_name='_TAG_OFFSET'): ...@@ -2821,6 +2821,9 @@ def sort_csr_by_tag(g, tag, tag_offset_name='_TAG_OFFSET'):
0 -> 2, 4, 0, 1, 3 0 -> 2, 4, 0, 1, 3
1 -> 2, 0, 1 1 -> 2, 0, 1
Edge tags ``[1, 1, 0, 2, 0, 1, 1, 0]`` have the same effect
as the node tags above.
The function will also return the starting offsets of the tag The function will also return the starting offsets of the tag
segments in a tensor of shape :math:`(N, max\_tag+2)`. For node ``i``, segments in a tensor of shape :math:`(N, max\_tag+2)`. For node ``i``,
its out-edges connecting to node tag ``j`` is stored between its out-edges connecting to node tag ``j`` is stored between
...@@ -2847,9 +2850,12 @@ def sort_csr_by_tag(g, tag, tag_offset_name='_TAG_OFFSET'): ...@@ -2847,9 +2850,12 @@ def sort_csr_by_tag(g, tag, tag_offset_name='_TAG_OFFSET'):
g : DGLGraph g : DGLGraph
The input graph. The input graph.
tag : Tensor tag : Tensor
Integer tensor of shape :math:`(N,)`, :math:`N` being the number of (destination) nodes. Integer tensor of shape :math:`(N,)`, :math:`N` being the number
of (destination) nodes or edges.
tag_offset_name : str tag_offset_name : str
The name of the node feature to store tag offsets. The name of the node feature to store tag offsets.
tag_type : str
Tag type which could be ``node`` or ``edge``.
Returns Returns
------- -------
...@@ -2863,6 +2869,9 @@ def sort_csr_by_tag(g, tag, tag_offset_name='_TAG_OFFSET'): ...@@ -2863,6 +2869,9 @@ def sort_csr_by_tag(g, tag, tag_offset_name='_TAG_OFFSET'):
Examples Examples
----------- -----------
``tag_type`` is ``node``.
>>> import dgl
>>> g = dgl.graph(([0,0,0,0,0,1,1,1],[0,1,2,3,4,0,1,2])) >>> g = dgl.graph(([0,0,0,0,0,1,1,1],[0,1,2,3,4,0,1,2]))
>>> g.adjacency_matrix(scipy_fmt='csr').nonzero() >>> g.adjacency_matrix(scipy_fmt='csr').nonzero()
(array([0, 0, 0, 0, 0, 1, 1, 1], dtype=int32), (array([0, 0, 0, 0, 0, 1, 1, 1], dtype=int32),
...@@ -2879,12 +2888,34 @@ def sort_csr_by_tag(g, tag, tag_offset_name='_TAG_OFFSET'): ...@@ -2879,12 +2888,34 @@ def sort_csr_by_tag(g, tag, tag_offset_name='_TAG_OFFSET'):
[0, 0, 0, 0], [0, 0, 0, 0],
[0, 0, 0, 0]]) [0, 0, 0, 0]])
``tag_type`` is ``edge``.
>>> from dgl import backend as F
>>> g = dgl.graph(([0,0,0,0,0,1,1,1],[0,1,2,3,4,0,1,2]))
>>> g.edges()
(tensor([0, 0, 0, 0, 0, 1, 1, 1]), tensor([0, 1, 2, 3, 4, 0, 1, 2]))
>>> tag = F.tensor([1, 1, 0, 2, 0, 1, 1, 0])
>>> g_sorted = dgl.sort_csr_by_tag(g, tag, tag_type='edge')
>>> g_sorted.adj(scipy_fmt='csr').nonzero()
(array([0, 0, 0, 0, 0, 1, 1, 1], dtype=int32), array([2, 4, 0, 1, 3, 2, 0, 1], dtype=int32))
>>> g_sorted.srcdata['_TAG_OFFSET']
tensor([[0, 2, 4, 5],
[0, 1, 3, 3],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]])
See Also See Also
-------- --------
dgl.sampling.sample_neighbors_biased dgl.sampling.sample_neighbors_biased
""" """
if len(g.etypes) > 1: if len(g.etypes) > 1:
raise DGLError("Only support homograph and bipartite graph") raise DGLError("Only support homograph and bipartite graph")
assert tag_type in ['node', 'edge'], "tag_type should be either 'node' or 'edge'."
if tag_type == 'node':
_, dst = g.edges()
tag = F.gather_row(tag, F.tensor(dst))
assert len(tag) == g.num_edges()
num_tags = int(F.asnumpy(F.max(tag, 0))) + 1 num_tags = int(F.asnumpy(F.max(tag, 0))) + 1
tag_arr = F.zerocopy_to_dgl_ndarray(tag) tag_arr = F.zerocopy_to_dgl_ndarray(tag)
new_g = g.clone() new_g = g.clone()
...@@ -2893,7 +2924,7 @@ def sort_csr_by_tag(g, tag, tag_offset_name='_TAG_OFFSET'): ...@@ -2893,7 +2924,7 @@ def sort_csr_by_tag(g, tag, tag_offset_name='_TAG_OFFSET'):
return new_g return new_g
def sort_csc_by_tag(g, tag, tag_offset_name='_TAG_OFFSET'): def sort_csc_by_tag(g, tag, tag_offset_name='_TAG_OFFSET', tag_type='node'):
r"""Return a new graph whose CSC matrix is sorted by the given tag. r"""Return a new graph whose CSC matrix is sorted by the given tag.
Sort the internal CSC matrix of the graph so that the adjacency list of each node Sort the internal CSC matrix of the graph so that the adjacency list of each node
...@@ -2913,6 +2944,9 @@ def sort_csc_by_tag(g, tag, tag_offset_name='_TAG_OFFSET'): ...@@ -2913,6 +2944,9 @@ def sort_csc_by_tag(g, tag, tag_offset_name='_TAG_OFFSET'):
0 <- 2, 4, 0, 1, 3 0 <- 2, 4, 0, 1, 3
1 <- 2, 0, 1 1 <- 2, 0, 1
Edge tags ``[1, 1, 0, 2, 0, 1, 1, 0]`` have the same effect
as the node tags above.
The function will also return the starting offsets of the tag The function will also return the starting offsets of the tag
segments in a tensor of shape :math:`(N, max\_tag+2)`. For a node ``i``, segments in a tensor of shape :math:`(N, max\_tag+2)`. For a node ``i``,
its in-edges connecting to node tag ``j`` is stored between its in-edges connecting to node tag ``j`` is stored between
...@@ -2939,9 +2973,12 @@ def sort_csc_by_tag(g, tag, tag_offset_name='_TAG_OFFSET'): ...@@ -2939,9 +2973,12 @@ def sort_csc_by_tag(g, tag, tag_offset_name='_TAG_OFFSET'):
g : DGLGraph g : DGLGraph
The input graph. The input graph.
tag : Tensor tag : Tensor
Integer tensor of shape :math:`(N,)`, :math:`N` being the number of (source) nodes. Integer tensor of shape :math:`(N,)`, :math:`N` being the number
of (source) nodes or edges.
tag_offset_name : str tag_offset_name : str
The name of the node feature to store tag offsets. The name of the node feature to store tag offsets.
tag_type : str
Tag type which could be ``node`` or ``edge``.
Returns Returns
------- -------
...@@ -2955,6 +2992,9 @@ def sort_csc_by_tag(g, tag, tag_offset_name='_TAG_OFFSET'): ...@@ -2955,6 +2992,9 @@ def sort_csc_by_tag(g, tag, tag_offset_name='_TAG_OFFSET'):
Examples Examples
----------- -----------
``tag_type`` is ``node``.
>>> import dgl
>>> g = dgl.graph(([0,1,2,3,4,0,1,2],[0,0,0,0,0,1,1,1])) >>> g = dgl.graph(([0,1,2,3,4,0,1,2],[0,0,0,0,0,1,1,1]))
>>> g.adjacency_matrix(scipy_fmt='csr', transpose=True).nonzero() >>> g.adjacency_matrix(scipy_fmt='csr', transpose=True).nonzero()
(array([0, 0, 0, 0, 0, 1, 1, 1], dtype=int32), (array([0, 0, 0, 0, 0, 1, 1, 1], dtype=int32),
...@@ -2971,12 +3011,32 @@ def sort_csc_by_tag(g, tag, tag_offset_name='_TAG_OFFSET'): ...@@ -2971,12 +3011,32 @@ def sort_csc_by_tag(g, tag, tag_offset_name='_TAG_OFFSET'):
[0, 0, 0, 0], [0, 0, 0, 0],
[0, 0, 0, 0]]) [0, 0, 0, 0]])
``tag_type`` is ``edge``.
>>> from dgl import backend as F
>>> g = dgl.graph(([0,1,2,3,4,0,1,2],[0,0,0,0,0,1,1,1]))
>>> tag = F.tensor([1, 1, 0, 2, 0, 1, 1, 0])
>>> g_sorted = dgl.sort_csc_by_tag(g, tag, tag_type='edge')
>>> g_sorted.adj(scipy_fmt='csr', transpose=True).nonzero()
(array([0, 0, 0, 0, 0, 1, 1, 1], dtype=int32), array([2, 4, 0, 1, 3, 2, 0, 1], dtype=int32))
>>> g_sorted.dstdata['_TAG_OFFSET']
tensor([[0, 2, 4, 5],
[0, 1, 3, 3],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]])
See Also See Also
-------- --------
dgl.sampling.sample_neighbors_biased dgl.sampling.sample_neighbors_biased
""" """
if len(g.etypes) > 1: if len(g.etypes) > 1:
raise DGLError("Only support homograph and bipartite graph") raise DGLError("Only support homograph and bipartite graph")
assert tag_type in ['node', 'edge'], "tag_type should be either 'node' or 'edge'."
if tag_type == 'node':
src, _ = g.edges()
tag = F.gather_row(tag, F.tensor(src))
assert len(tag) == g.num_edges()
num_tags = int(F.asnumpy(F.max(tag, 0))) + 1 num_tags = int(F.asnumpy(F.max(tag, 0))) + 1
tag_arr = F.zerocopy_to_dgl_ndarray(tag) tag_arr = F.zerocopy_to_dgl_ndarray(tag)
new_g = g.clone() new_g = g.clone()
......
...@@ -516,8 +516,8 @@ void CSRSort_(CSRMatrix* csr) { ...@@ -516,8 +516,8 @@ void CSRSort_(CSRMatrix* csr) {
std::pair<CSRMatrix, NDArray> CSRSortByTag( std::pair<CSRMatrix, NDArray> CSRSortByTag(
const CSRMatrix &csr, IdArray tag, int64_t num_tags) { const CSRMatrix &csr, IdArray tag, int64_t num_tags) {
CHECK_EQ(csr.num_cols, tag->shape[0]) CHECK_EQ(csr.indices->shape[0], tag->shape[0])
<< "The length of the tag array should be equal to the number of columns "; << "The length of the tag array should be equal to the number of non-zero data.";
CHECK_SAME_CONTEXT(csr.indices, tag); CHECK_SAME_CONTEXT(csr.indices, tag);
CHECK_INT(tag, "tag"); CHECK_INT(tag, "tag");
std::pair<CSRMatrix, NDArray> ret; std::pair<CSRMatrix, NDArray> ret;
......
...@@ -87,9 +87,9 @@ std::pair<CSRMatrix, NDArray> CSRSortByTag( ...@@ -87,9 +87,9 @@ std::pair<CSRMatrix, NDArray> CSRSortByTag(
const CSRMatrix &csr, const IdArray tag_array, int64_t num_tags) { const CSRMatrix &csr, const IdArray tag_array, int64_t num_tags) {
const auto indptr_data = static_cast<const IdType *>(csr.indptr->data); const auto indptr_data = static_cast<const IdType *>(csr.indptr->data);
const auto indices_data = static_cast<const IdType *>(csr.indices->data); const auto indices_data = static_cast<const IdType *>(csr.indices->data);
const auto eid_array = aten::CSRHasData(csr) ? csr.data : const auto eid_data = aten::CSRHasData(csr)
aten::Range(0, csr.indices->shape[0], csr.indptr->dtype.bits, csr.indptr->ctx); ? static_cast<const IdType *>(csr.data->data)
const auto eid_data = static_cast<const IdType *>(csr.data->data); : nullptr;
const auto tag_data = static_cast<const TagType *>(tag_array->data); const auto tag_data = static_cast<const TagType *>(tag_array->data);
const int64_t num_rows = csr.num_rows; const int64_t num_rows = csr.num_rows;
...@@ -98,9 +98,11 @@ std::pair<CSRMatrix, NDArray> CSRSortByTag( ...@@ -98,9 +98,11 @@ std::pair<CSRMatrix, NDArray> CSRSortByTag(
auto tag_pos_data = static_cast<IdType *>(tag_pos->data); auto tag_pos_data = static_cast<IdType *>(tag_pos->data);
std::fill(tag_pos_data, tag_pos_data + csr.num_rows * (num_tags + 1), 0); std::fill(tag_pos_data, tag_pos_data + csr.num_rows * (num_tags + 1), 0);
aten::CSRMatrix output(csr.num_rows, csr.num_cols, aten::CSRMatrix output(csr.num_rows, csr.num_cols, csr.indptr.Clone(),
csr.indptr.Clone(), csr.indices.Clone(), csr.indices.Clone(),
eid_array.Clone(), csr.sorted); NDArray::Empty({csr.indices->shape[0]},
csr.indices->dtype, csr.indices->ctx),
csr.sorted);
auto out_indices_data = static_cast<IdType *>(output.indices->data); auto out_indices_data = static_cast<IdType *>(output.indices->data);
auto out_eid_data = static_cast<IdType *>(output.data->data); auto out_eid_data = static_cast<IdType *>(output.data->data);
...@@ -114,8 +116,8 @@ std::pair<CSRMatrix, NDArray> CSRSortByTag( ...@@ -114,8 +116,8 @@ std::pair<CSRMatrix, NDArray> CSRSortByTag(
std::vector<IdType> pointer(num_tags, 0); std::vector<IdType> pointer(num_tags, 0);
for (IdType ptr = start ; ptr < end ; ++ptr) { for (IdType ptr = start ; ptr < end ; ++ptr) {
const IdType dst = indices_data[ptr]; const IdType eid = eid_data ? eid_data[ptr] : ptr;
const TagType tag = tag_data[dst]; const TagType tag = tag_data[eid];
CHECK_LT(tag, num_tags); CHECK_LT(tag, num_tags);
++tag_pos_row[tag + 1]; ++tag_pos_row[tag + 1];
} // count } // count
...@@ -126,8 +128,8 @@ std::pair<CSRMatrix, NDArray> CSRSortByTag( ...@@ -126,8 +128,8 @@ std::pair<CSRMatrix, NDArray> CSRSortByTag(
for (IdType ptr = start ; ptr < end ; ++ptr) { for (IdType ptr = start ; ptr < end ; ++ptr) {
const IdType dst = indices_data[ptr]; const IdType dst = indices_data[ptr];
const IdType eid = eid_data[ptr]; const IdType eid = eid_data ? eid_data[ptr] : ptr;
const TagType tag = tag_data[dst]; const TagType tag = tag_data[eid];
const IdType offset = tag_pos_row[tag] + pointer[tag]; const IdType offset = tag_pos_row[tag] + pointer[tag];
CHECK_LT(offset, tag_pos_row[tag + 1]); CHECK_LT(offset, tag_pos_row[tag + 1]);
++pointer[tag]; ++pointer[tag];
......
...@@ -54,17 +54,24 @@ def test_sort_with_tag(idtype): ...@@ -54,17 +54,24 @@ def test_sort_with_tag(idtype):
num_nodes, num_adj, num_tags = 200, [20, 50], 5 num_nodes, num_adj, num_tags = 200, [20, 50], 5
g = create_test_heterograph(num_nodes, num_adj, idtype=idtype) g = create_test_heterograph(num_nodes, num_adj, idtype=idtype)
tag = F.tensor(np.random.choice(num_tags, g.number_of_nodes())) tag = F.tensor(np.random.choice(num_tags, g.number_of_nodes()))
src, dst = g.edges()
edge_tag_dst = F.gather_row(tag, F.tensor(dst))
edge_tag_src = F.gather_row(tag, F.tensor(src))
new_g = dgl.sort_csr_by_tag(g, tag) for tag_type in ['node', 'edge']:
new_g = dgl.sort_csr_by_tag(
g, tag if tag_type == 'node' else edge_tag_dst, tag_type=tag_type)
old_csr = g.adjacency_matrix(scipy_fmt='csr') old_csr = g.adjacency_matrix(scipy_fmt='csr')
new_csr = new_g.adjacency_matrix(scipy_fmt='csr') new_csr = new_g.adjacency_matrix(scipy_fmt='csr')
assert(check_sort(new_csr, tag, new_g.ndata["_TAG_OFFSET"])) assert(check_sort(new_csr, tag, new_g.dstdata["_TAG_OFFSET"]))
assert(not check_sort(old_csr, tag)) # Check the original csr is not modified. assert(not check_sort(old_csr, tag)) # Check the original csr is not modified.
new_g = dgl.sort_csc_by_tag(g, tag) for tag_type in ['node', 'edge']:
new_g = dgl.sort_csc_by_tag(
g, tag if tag_type == 'node' else edge_tag_src, tag_type=tag_type)
old_csc = g.adjacency_matrix(transpose=True, scipy_fmt='csr') old_csc = g.adjacency_matrix(transpose=True, scipy_fmt='csr')
new_csc = new_g.adjacency_matrix(transpose=True, scipy_fmt='csr') new_csc = new_g.adjacency_matrix(transpose=True, scipy_fmt='csr')
assert(check_sort(new_csc, tag, new_g.ndata["_TAG_OFFSET"])) assert(check_sort(new_csc, tag, new_g.srcdata["_TAG_OFFSET"]))
assert(not check_sort(old_csc, tag)) assert(not check_sort(old_csc, tag))
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sorting by tag not implemented") @unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sorting by tag not implemented")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment