test_sort.py 3.87 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import dgl
import dgl.function as fn
from collections import Counter
import numpy as np
import scipy.sparse as ssp
import itertools
import backend as F
import networkx as nx
import unittest, pytest
from dgl import DGLError
from utils import parametrize_dtype

def create_test_heterograph(num_nodes, num_adj, idtype):
    """Build a random graph in which every node gets a random out-degree.

    Parameters
    ----------
    num_nodes : int
        Number of source nodes to generate edges for.
    num_adj : int or sequence
        Either a fixed out-degree, or a ``[low, high)`` pair from which the
        out-degree of each node is drawn.
    idtype
        Forwarded to ``dgl.graph`` as ``idtype``.

    Returns
    -------
    The value returned by ``dgl.graph((src, dst), idtype=idtype)``.
    """
    if isinstance(num_adj, int):
        # A single integer means "exactly this many neighbours per node".
        num_adj = [num_adj, num_adj + 1]
    lower, upper = num_adj[0], num_adj[1]
    degrees = list(np.random.choice(np.arange(lower, upper), num_nodes))

    # Source ids: node `u` repeated once for each of its outgoing edges.
    src_chunks = [[node] * degree for node, degree in enumerate(degrees)]
    # Destination ids: distinct neighbours sampled per node (no duplicates
    # within a node's neighbourhood).
    dst_chunks = [
        np.random.choice(num_nodes, degree, replace=False)
        for degree in degrees
    ]

    src = np.concatenate(src_chunks)
    dst = np.concatenate(dst_chunks)
    return dgl.graph((src, dst), idtype=idtype)

def check_sort(spm, tag_arr=None, tag_pos=None):
    """Check that the entries of every row of ``spm`` are ordered by tag.

    For each row, the column indices are visited in the order reported by
    ``spm.getrow(i).nonzero()`` and their tags must be non-decreasing.

    Parameters
    ----------
    spm : scipy sparse matrix
        Matrix whose per-row column ordering is verified.
    tag_arr : tensor, optional
        Tag of every column, converted with ``F.asnumpy``. When omitted, the
        column index itself is used as the tag.
    tag_pos : tensor, optional
        Per-row tag offsets, converted with ``F.asnumpy``. When given, every
        position in row ``i`` where the tag changes from ``t`` to a larger
        tag must equal ``tag_pos[i][t + 1]``.

    Returns
    -------
    bool
        ``True`` if every row satisfies the checks above, ``False`` otherwise.
    """
    if tag_arr is None:
        # Tags are looked up by column index, so the default identity tagging
        # must cover every column (not every row); otherwise matrices with
        # more columns than rows fail with an IndexError.
        tag_arr = np.arange(spm.shape[1])
    else:
        tag_arr = F.asnumpy(tag_arr)
    if tag_pos is not None:
        tag_pos = F.asnumpy(tag_pos)
    for i in range(spm.shape[0]):
        row = spm.getrow(i)
        dst = row.nonzero()[1]
        if tag_pos is not None:
            tag_pos_row = tag_pos[i]
            # Tag of the segment currently being scanned.
            tag_pos_ptr = tag_arr[dst[0]] if len(dst) > 0 else 0
        for j in range(len(dst) - 1):
            if tag_pos is not None and tag_arr[dst[j]] != tag_pos_ptr:
                # `tag_pos_ptr` is the expected tag value. Here we check whether the
                # tag value is equal to `tag_pos_ptr`
                return False
            if tag_arr[dst[j]] > tag_arr[dst[j + 1]]:
                # The tags should be in ascending (non-decreasing) order
                # after sorting
                return False
            if tag_pos is not None and tag_arr[dst[j]] < tag_arr[dst[j + 1]]:
                if j + 1 != int(tag_pos_row[tag_pos_ptr + 1]):
                    # The boundary of tag should be consistent with `tag_pos`
                    return False
                tag_pos_ptr = tag_arr[dst[j + 1]]
    return True


@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sorting by tag not implemented")
@parametrize_dtype
def test_sort_with_tag(idtype):
    """Sort out-edges and in-edges of a random graph by node tag.

    For each direction the sorted graph's adjacency matrix must pass
    ``check_sort`` together with its ``_TAG_OFFSET`` node data, while the
    adjacency matrix of the input graph must still fail it.
    """
    num_nodes, num_adj, num_tags = 200, [20, 50], 5
    g = create_test_heterograph(num_nodes, num_adj, idtype=idtype)
    tag = F.tensor(np.random.choice(num_tags, g.number_of_nodes()))

    # (sorting routine, extra keyword arguments for `adjacency_matrix`)
    directions = (
        (dgl.sort_out_edges, {}),
        (dgl.sort_in_edges, {'transpose': False}),
    )
    for sort_edges, adj_kwargs in directions:
        sorted_g = sort_edges(g, tag)
        adj_before = g.adjacency_matrix(scipy_fmt='csr', **adj_kwargs)
        adj_after = sorted_g.adjacency_matrix(scipy_fmt='csr', **adj_kwargs)
        assert check_sort(adj_after, tag, sorted_g.ndata["_TAG_OFFSET"])
        # Check the original matrix is not modified.
        assert not check_sort(adj_before, tag)

@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sorting by tag not implemented")
@parametrize_dtype
def test_sort_with_tag_bipartite(idtype):
    """Sort edges of a ``_U -> _V`` graph by the tags of the opposite side.

    Out-edges are sorted by the ``_V`` tags and checked against the
    ``_TAG_OFFSET`` stored on ``_U``; in-edges are sorted by the ``_U`` tags
    and checked against the ``_TAG_OFFSET`` stored on ``_V``. In both cases
    the adjacency matrix of the input graph must still fail ``check_sort``.
    """
    num_nodes, num_adj, num_tags = 200, [20, 50], 5
    g = create_test_heterograph(num_nodes, num_adj, idtype=idtype)
    g = dgl.heterograph({('_U', '_E', '_V') : g.edges()})
    utag = F.tensor(np.random.choice(num_tags, g.number_of_nodes('_U')))
    vtag = F.tensor(np.random.choice(num_tags, g.number_of_nodes('_V')))

    # (sorting routine, tags used for sorting, node type holding the offsets,
    #  extra keyword arguments for `adjacency_matrix`)
    directions = (
        (dgl.sort_out_edges, vtag, '_U', {}),
        (dgl.sort_in_edges, utag, '_V', {'transpose': False}),
    )
    for sort_edges, tags, offset_ntype, adj_kwargs in directions:
        sorted_g = sort_edges(g, tags)
        adj_before = g.adjacency_matrix(scipy_fmt='csr', **adj_kwargs)
        adj_after = sorted_g.adjacency_matrix(scipy_fmt='csr', **adj_kwargs)
        offsets = sorted_g.nodes[offset_ntype].data['_TAG_OFFSET']
        assert check_sort(adj_after, tags, offsets)
        assert not check_sort(adj_before, tags)

if __name__ == "__main__":
    # Run every test in this module once with 32-bit ids when executed as a
    # script.
    for run_test in (test_sort_with_tag, test_sort_with_tag_bipartite):
        run_test(F.int32)