coalesce.py 1.41 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
import torch
import torch_scatter

rusty1s's avatar
rusty1s committed
4
5
from .utils.unique import unique

rusty1s's avatar
rusty1s committed
6

rusty1s's avatar
rusty1s committed
7
def coalesce(index, value, m, n, op='add'):
    """Row-wise sorts :obj:`index` and :obj:`value` and removes duplicate
    entries. Duplicate entries are removed by scattering them together. For
    scattering, any operation of `torch_scatter
    <https://github.com/rusty1s/pytorch_scatter>`_ can be used.

    Args:
        index (:class:`LongTensor`): The index tensor of sparse matrix,
            holding row indices in ``index[0]`` and column indices in
            ``index[1]``.
        value (:class:`Tensor`, optional): The value tensor of sparse matrix.
            May be :obj:`None`, in which case only :obj:`index` is coalesced
            and :obj:`None` is returned in place of the values.
        m (int): The first dimension of corresponding dense matrix. (Not
            used by the computation itself.)
        n (int): The second dimension of corresponding dense matrix. Column
            indices are expected to lie in ``[0, n)`` so that
            ``row * n + col`` uniquely identifies every entry.
        op (string, optional): The scatter operation to use. (default:
            :obj:`"add"`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    row, col = index

    # Linearize every (row, col) pair into one integer key. Sorting the keys
    # orders entries row-major, and equal keys mark duplicate entries.
    uniq, inv = torch.unique(row * n + col, sorted=True, return_inverse=True)

    # For each unique key, record one input position that holds it. Which
    # duplicate gets recorded does not matter: all duplicates share the same
    # (row, col) pair, so the gathered index is the same either way.
    perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device)
    perm = inv.new_empty(uniq.size(0)).scatter_(0, inv, perm)

    index = torch.stack([row[perm], col[perm]], dim=0)

    # Without values there is nothing to reduce; only the index is coalesced.
    if value is None:
        return index, value

    op = getattr(torch_scatter, 'scatter_{}'.format(op))
    # Reduce all values that share a key into that key's slot (dim 0).
    value = op(value, inv, 0, None, perm.size(0))
    # Some scatter ops return a tuple (presumably ``(out, arg)`` for
    # max/min-style reductions); only the reduced output is kept.
    value = value[0] if isinstance(value, tuple) else value

    return index, value