from itertools import product

import pytest
import torch
from torch.autograd import gradcheck
from torch_sparse import SparseTensor, spspmm, to_value

from .utils import dtypes, devices, tensor

# Narrow the shared parametrization lists imported from .utils to a single
# CPU device and float64 for this module.
# NOTE(review): presumably restricted because the gradcheck below is meant to
# run on double-precision inputs — confirm whether the full device/dtype
# matrix from .utils should be restored for the forward-only assertions.
devices = [torch.device('cpu')]
dtypes = [torch.double]


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_coalesced_spspmm(dtype, device):
    """Check ``spspmm`` forward values and its gradient w.r.t. ``valueB``.

    Forward: multiplies two ``torch.sparse_coo_tensor`` operands (indices
    sorted and duplicate-free) and compares the dense result against a
    hand-computed product.

    Backward: runs ``torch.autograd.gradcheck`` on a small pipeline that
    rebuilds both operands as ``SparseTensor`` objects from their value
    vectors and returns ``to_value`` of their ``spspmm`` product.
    """
    # A (3x3): (0,1)=1, (0,2)=2, (1,0)=3, (2,0)=4, (2,1)=5.
    indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]], device=device)
    valueA = tensor([1, 2, 3, 4, 5], dtype, device)
    sizeA = torch.Size([3, 3])
    A = torch.sparse_coo_tensor(indexA, valueA, sizeA, device=device)

    # B (3x2): (0,1)=2, (2,0)=4.
    indexB = torch.tensor([[0, 2], [1, 0]], device=device)
    valueB = tensor([2, 4], dtype, device)
    sizeB = torch.Size([3, 2])
    B = torch.sparse_coo_tensor(indexB, valueB, sizeB, device=device)

    # Row 0: 2*B[2] = [8, 0]; row 1: 3*B[0] = [0, 6]; row 2: 4*B[0] = [0, 8].
    assert spspmm(A, B).to_dense().tolist() == [[8, 0], [0, 6], [0, 8]]

    # NOTE(review): only valueB requires grad, so gradcheck verifies the
    # gradient w.r.t. the right operand only. valueA.requires_grad_() was
    # left disabled by the original author — confirm the left-operand
    # backward is supported before enabling it here.
    valueB.requires_grad_()

    def pipeline(value_a, value_b):
        # Rebuild both operands from their value vectors so gradcheck can
        # perturb the values and observe the effect on the product.
        lhs = SparseTensor(indexA, value_a, sizeA)
        rhs = SparseTensor(indexB, value_b, sizeB)
        return to_value(spspmm(lhs, rhs))

    # gradcheck raises on a mismatch by default and returns True otherwise;
    # assert the result so the check is explicit rather than just printed.
    assert gradcheck(pipeline, (valueA, valueB), eps=1e-6, atol=1e-4)