sparse.py 1.38 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
import torch


rusty1s's avatar
docs  
rusty1s committed
4
5
6
class SparseCooTensor(torch.autograd.Function):
    """Construct a sparse COO tensor with autograd support w.r.t. its values.

    Only ``value`` receives a gradient; ``index`` and ``size`` are treated as
    constants.
    """

    @staticmethod
    def forward(ctx, index, value, size):
        """Build ``torch.sparse_coo_tensor(index, value, size)`` on ``value``'s device.

        Args:
            index: Integer coordinate tensor of shape ``(sparse_dim, nnz)``.
            value: Tensor whose first dimension has length ``nnz``.
            size: Shape of the resulting sparse tensor.
        """
        ctx.size = size
        ctx.save_for_backward(index)
        return torch.sparse_coo_tensor(index, value, size, device=value.device)

    @staticmethod
    def backward(ctx, grad_out):
        """Gather the entries of ``grad_out`` at the input coordinates.

        Returns ``(None, grad_value, None)``. ``grad_value[k]`` is the gradient
        at coordinate ``index[:, k]``. ``grad_out`` may be sparse (possibly
        uncoalesced) or dense.
        """
        (index,) = ctx.saved_tensors
        grad_in = None

        if ctx.needs_input_grad[1]:
            if grad_out.is_sparse:
                grad_in = SparseCooTensor._gather_sparse(
                    index, grad_out, ctx.size)
            else:
                # Dense gradient: read it directly at every input coordinate.
                grad_in = grad_out[tuple(index)]

        return None, grad_in, None

    @staticmethod
    def _linearize(index, size):
        """Map each coordinate column of ``index`` to its row-major flat offset."""
        flat = index[0]
        for dim in range(1, index.size(0)):
            flat = flat * size[dim] + index[dim]
        return flat

    @staticmethod
    def _gather_sparse(index, grad_out, size):
        """Return the values of sparse ``grad_out`` at the coordinates in ``index``."""
        grad_values = grad_out._values()
        # Trailing (dense / hybrid) dimensions of the values, if any.
        trailing = tuple(grad_values.size()[1:])

        if index.size(1) == 0:
            # No input coordinates: nothing to gather (and max() would fail).
            return grad_values.new_zeros((0,) + trailing)

        src = SparseCooTensor._linearize(index, size)
        dst = SparseCooTensor._linearize(grad_out._indices(), size)

        # Only offsets the input can address matter. Dropping the rest keeps
        # the scratch buffer as small as the input's max offset while avoiding
        # out-of-range writes for gradient entries outside the input pattern.
        num_slots = int(src.max().item()) + 1
        keep = dst < num_slots

        buffer = grad_values.new_zeros((num_slots,) + trailing)
        # index_add_ sums duplicate coordinates, which is the meaning of an
        # uncoalesced sparse gradient; plain assignment would drop them.
        buffer.index_add_(0, dst[keep], grad_values[keep])
        return buffer[src]


rusty1s's avatar
docs  
rusty1s committed
31
# Functional entry point: sparse_coo_tensor(index, value, size) builds a sparse
# COO tensor whose gradient flows back to `value` only.
sparse_coo_tensor = SparseCooTensor.apply
rusty1s's avatar
rusty1s committed
32
33
34


class ToValue(torch.autograd.Function):
    """Extract the stored values of a sparse COO tensor with autograd support.

    The gradient w.r.t. the sparse input is a sparse tensor with the input's
    indices and size, carrying ``grad_out`` as its values.
    """

    @staticmethod
    def forward(ctx, A):
        """Return ``A._values()``, the (possibly uncoalesced) stored values of ``A``."""
        ctx.save_for_backward(A)
        return A._values()

    @staticmethod
    def backward(ctx, grad_out):
        """Scatter ``grad_out`` back onto ``A``'s sparsity pattern.

        Returns a sparse COO tensor with ``A``'s indices and size, or ``None``
        when ``A`` does not require a gradient.
        """
        # `saved_tensors` replaces the deprecated `saved_variables` accessor.
        (A,) = ctx.saved_tensors
        grad_in = None

        if ctx.needs_input_grad[0]:
            grad_in = torch.sparse_coo_tensor(
                A._indices(), grad_out, A.size(), device=grad_out.device)

        return grad_in


# Functional entry point: to_value(A) returns A's stored values, with the
# gradient routed back to A as a sparse tensor sharing A's indices.
to_value = ToValue.apply