"docs/vscode:/vscode.git/clone" did not exist on "d248e7686f8487438261f34c49217b88c2a6dfbe"
test_reduction.py 4.76 KB
Newer Older
1
2
3
4
5
6
import doctest
import operator

import backend as F

import dgl.sparse as dglsp
import pytest
import torch

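# Maps each torch reduction name to the corresponding dglsp method name.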
dgl_op_map = {
    "sum": "sum",
    "amin": "smin",
    "amax": "smax",
    "mean": "smean",
    "prod": "sprod",
}
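# Identity element of each reduction, used to pad the dense tensor so the
# padding does not affect the result.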
default_entry = {
    "sum": 0,
    "amin": float("inf"),
    "amax": float("-inf"),
    "mean": 0,
    "prod": 1,
}
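# Elementwise binary form of each reduction (mean accumulates with add and
# divides by the count afterwards).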
binary_op_map = {
    "sum": operator.add,
    "amin": torch.min,
    "amax": torch.max,
    "mean": operator.add,
    "prod": operator.mul,
}

NUM_ROWS = 10
NUM_COLS = 15


def _coalesce_dense(row, col, val, nrows, ncols, op):
    # Sparse matrix coalescing on a dense matrix.
    #
    # It is done by stacking every non-zero entry on an individual slice
    # of an (nrows x ncols x nnz) tensor, that is, constructing a tensor
    # A with shape (nrows, ncols, nnz) where
    #
    #     A[row[i], col[i], i] = val[i]
    #
    # and then reducing over the third "nnz" dimension.  Slices are
    # padded with the identity element of the reduction so the padding
    # does not affect the result.
    #
    # The count matrix M records how many non-zero entries land on each
    # (row, col) position; it is the divisor when the reduce operator is
    # mean.
    nnz = row.shape[0]
    M = torch.zeros(nrows, ncols, device=F.ctx())
    A = torch.full(
        (nrows, ncols, nnz) + val.shape[1:],
        default_entry[op],
        device=F.ctx(),
        dtype=val.dtype,
    )
    A = torch.index_put(
        A, (row, col, torch.arange(nnz, device=row.device)), val
    )
    for i in range(nnz):
        M[row[i], col[i]] += 1
    if op == "mean":
        A = A.sum(2)
    else:
        A = getattr(A, op)(2)
    M = M.view(nrows, ncols, *([1] * (val.dim() - 1)))
    return A, M
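

# A minimal sketch (an addition, not part of the original suite) showing
# what _coalesce_dense computes: duplicate coordinates coalesce under the
# reduce operator.  The helper name _demo_coalesce_dense is hypothetical;
# the leading underscore keeps pytest from collecting it.
def _demo_coalesce_dense():
    row = torch.tensor([0, 0, 1], device=F.ctx())
    col = torch.tensor([0, 0, 1], device=F.ctx())
    val = torch.tensor([1.0, 2.0, 3.0], device=F.ctx())
    A, M = _coalesce_dense(row, col, val, 2, 2, "sum")
    assert A[0, 0] == 3.0  # 1.0 + 2.0 coalesced by sum
    assert M[0, 0] == 2  # two entries landed on (0, 0)
    assert A[1, 1] == 3.0 and M[1, 1] == 1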


# Run the doctests of the dglsp reduction operators as unit tests.
@pytest.mark.parametrize(
    "func", ["reduce", "sum", "smin", "smax", "sprod", "smean"]
)
def test_docstring(func):
    globs = {"torch": torch, "dglsp": dglsp}
    runner = doctest.DebugRunner()
    finder = doctest.DocTestFinder()
    obj = getattr(dglsp, func)
    for test in finder.find(obj, func, globs=globs):
        runner.run(test)


@pytest.mark.parametrize("shape", [(20,), (20, 20)])
@pytest.mark.parametrize("op", ["sum", "amin", "amax", "mean", "prod"])
@pytest.mark.parametrize("use_reduce", [False, True])
def test_reduce_all(shape, op, use_reduce):
    row = torch.randint(0, NUM_ROWS, (20,), device=F.ctx())
    col = torch.randint(0, NUM_COLS, (20,), device=F.ctx())
    val = torch.randn(*shape, device=F.ctx())
    val2 = val.clone()
    val = val.requires_grad_()
    val2 = val2.requires_grad_()
    A = dglsp.from_coo(row, col, val, shape=(NUM_ROWS, NUM_COLS))

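    # Dense reference: coalesce the same COO data into a dense tensor and
    # reduce it with plain torch ops.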
    A2, M = _coalesce_dense(row, col, val2, NUM_ROWS, NUM_COLS, op)

    if not use_reduce:
        output = getattr(A, dgl_op_map[op])()
    else:
        output = A.reduce(rtype=dgl_op_map[op])

    if op == "mean":
        output2 = A2.sum((0, 1)) / M.sum()
    elif op == "prod":
        output2 = A2.prod(0).prod(0)  # prod() does not support tuple of dims
    else:
        output2 = getattr(A2, op)((0, 1))
    assert (output - output2).abs().max() < 1e-4

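    # The backward pass must match the dense reference as well; seed it
    # with a random upstream gradient when the output is not a scalar.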
    head = torch.randn(*output.shape).to(val) if output.dim() > 0 else None
    output.backward(head)
    output2.backward(head)
    assert (val.grad - val2.grad).abs().max() < 1e-4


@pytest.mark.parametrize("shape", [(20,), (20, 20)])
@pytest.mark.parametrize("dim", [0, 1])
@pytest.mark.parametrize("empty_nnz", [False, True])
@pytest.mark.parametrize("op", ["sum", "amin", "amax", "mean", "prod"])
@pytest.mark.parametrize("use_reduce", [False, True])
def test_reduce_along(shape, dim, empty_nnz, op, use_reduce):
    row = torch.randint(0, NUM_ROWS, (20,), device=F.ctx())
    col = torch.randint(0, NUM_COLS, (20,), device=F.ctx())
    val = torch.randn(*shape, device=F.ctx())
    val2 = val.clone()
    val = val.requires_grad_()
    val2 = val2.requires_grad_()

    # empty_nnz controls whether at least one row and one column have no
    # non-zero entry: remapping index 0 to index 1 guarantees that row 0
    # and column 0 stay empty.
    if empty_nnz:
        row[row == 0] = 1
        col[col == 0] = 1

    A = dglsp.from_coo(row, col, val, shape=(NUM_ROWS, NUM_COLS))

    A2, M = _coalesce_dense(row, col, val2, NUM_ROWS, NUM_COLS, op)

    if not use_reduce:
        output = getattr(A, dgl_op_map[op])(dim)
    else:
        output = A.reduce(dim=dim, rtype=dgl_op_map[op])

    if op == "mean":
        output2 = A2.sum(dim) / M.sum(dim)
    else:
        output2 = getattr(A2, op)(dim)
    # Rows/columns without any non-zero entry reduce to zero in dglsp, so
    # rebuild the dense reference with zeros at the empty positions (this
    # also drops the NaN a mean over an empty row/column would produce).
    nonzero_entry_idx = (M.sum(dim) != 0).nonzero(as_tuple=True)[0]
    output3 = torch.index_put(
        torch.zeros_like(output2),
        (nonzero_entry_idx,),
        output2[nonzero_entry_idx],
    )
    assert (output - output3).abs().max() < 1e-4

    head = torch.randn(*output.shape).to(val) if output.dim() > 0 else None
    output.backward(head)
    output3.backward(head)
    assert (val.grad - val2.grad).abs().max() < 1e-4