"vscode:/vscode.git/clone" did not exist on "20f6c0656906f103d1962b67789a8b7ae8515514"
test_mm.py 3.69 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import sys

import backend as F
import pytest
import torch

from dgl.mock_sparse2 import create_from_coo, create_from_csc, create_from_csr

# TODO(#4818): Skipping tests on win.
# Module-level skip: pytest aborts collection of this entire file on any
# non-Linux platform, so none of the tests below even get imported there.
if not sys.platform.startswith("linux"):
    pytest.skip("skipping tests on win", allow_module_level=True)


def get_adj(A):
    """Materialize the sparse wrapper ``A`` as a coalesced torch COO tensor.

    Used by the tests below to build a reference adjacency for
    ``torch.sparse.mm``. If the nonzero values are multi-dimensional,
    the trailing value dimension is appended to the dense shape.
    """
    row_idx, col_idx = A.coo()
    indices = torch.stack((row_idx, col_idx), dim=0)
    # Detach so the reference tensor is a fresh autograd leaf.
    values = A.val.detach()
    dense_shape = A.shape
    if values.dim() > 1:
        dense_shape = dense_shape + (values.shape[-1],)
    return torch.sparse_coo_tensor(indices, values, dense_shape).coalesce()


def test_spmm_coo():
    """``A @ X`` for a COO sparse matrix must match the torch.sparse
    reference, in forward values and in gradients w.r.t. both operands."""
    dev = F.ctx()
    # A: shape (N, M), X: shape (M, F)
    row = torch.tensor([0, 1, 1, 1]).to(dev)
    col = torch.tensor([1, 0, 1, 2]).to(dev)
    val = torch.randn(len(row), requires_grad=True, device=dev)
    A = create_from_coo(row, col, val)
    X = torch.randn(3, 4, requires_grad=True, device=dev)

    result = A @ X
    out_grad = torch.randn_like(result)
    result.backward(out_grad)

    # Reference computation through torch's own sparse matmul.
    torch_adj = get_adj(A)
    torch_adj.requires_grad_()
    X_ref = X.clone().detach().requires_grad_()
    expected = torch.sparse.mm(torch_adj, X_ref)
    expected.backward(out_grad)

    assert torch.allclose(result, expected)
    assert torch.allclose(X.grad, X_ref.grad)
    assert torch.allclose(torch_adj.grad.coalesce().values(), val.grad)


def test_spmm_coo_one_dim_rhs():
    """``A @ x`` with a 1-D dense operand (matrix-vector product)."""
    dev = F.ctx()
    # A: shape (N, M), X: shape (M,)
    row = torch.tensor([0, 1, 1, 1]).to(dev)
    col = torch.tensor([1, 0, 1, 2]).to(dev)
    val = torch.randn(len(row), requires_grad=True, device=dev)
    A = create_from_coo(row, col, val)
    X = torch.randn(3, requires_grad=True, device=dev)

    result = A @ X
    out_grad = torch.randn_like(result)
    result.backward(out_grad)

    torch_adj = get_adj(A)
    torch_adj.requires_grad_()
    X_ref = X.clone().detach().requires_grad_()
    # torch.sparse.mm needs 2-D operands, so lift the vector to a column
    # and flatten the product back to 1-D for comparison.
    expected = torch.sparse.mm(torch_adj, X_ref.view(-1, 1)).view(-1)
    expected.backward(out_grad)

    assert torch.allclose(result, expected)
    assert torch.allclose(X.grad, X_ref.grad)
    assert torch.allclose(torch_adj.grad.coalesce().values(), val.grad)


def test_spmm_csr():
    """``A @ X`` for a CSR sparse matrix must match the torch.sparse
    reference, in forward values and in gradients w.r.t. both operands."""
    dev = F.ctx()
    # A: shape (N, M), X: shape (M, F)
    indptr = torch.tensor([0, 1, 4]).to(dev)
    indices = torch.tensor([1, 0, 1, 2]).to(dev)
    val = torch.randn(len(indices), requires_grad=True, device=dev)
    A = create_from_csr(indptr, indices, val, shape=(2, 3))
    X = torch.randn(3, 4, requires_grad=True, device=dev)

    result = A @ X
    out_grad = torch.randn_like(result)
    result.backward(out_grad)

    # Reference computation through torch's own sparse matmul.
    torch_adj = get_adj(A)
    torch_adj.requires_grad_()
    X_ref = X.clone().detach().requires_grad_()
    expected = torch.sparse.mm(torch_adj, X_ref)
    expected.backward(out_grad)

    assert torch.allclose(result, expected)
    assert torch.allclose(X.grad, X_ref.grad)
    assert torch.allclose(torch_adj.grad.coalesce().values(), val.grad)


def test_spmm_csc():
    """``A @ X`` for a CSC sparse matrix must match the torch.sparse
    reference, in forward values and in gradients w.r.t. both operands."""
    dev = F.ctx()
    # A: shape (N, M), X: shape (M, F)
    indptr = torch.tensor([0, 1, 3, 4]).to(dev)
    indices = torch.tensor([0, 0, 1, 1]).to(dev)
    val = torch.randn(len(indices), requires_grad=True, device=dev)
    A = create_from_csc(indptr, indices, val, shape=(2, 3))
    X = torch.randn(3, 4, requires_grad=True, device=dev)

    result = A @ X
    out_grad = torch.randn_like(result)
    result.backward(out_grad)

    # Reference computation through torch's own sparse matmul.
    torch_adj = get_adj(A)
    torch_adj.requires_grad_()
    X_ref = X.clone().detach().requires_grad_()
    expected = torch.sparse.mm(torch_adj, X_ref)
    expected.backward(out_grad)

    assert torch.allclose(result, expected)
    assert torch.allclose(X.grad, X_ref.grad)
    assert torch.allclose(torch_adj.grad.coalesce().values(), val.grad)