coalesce.py 646 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
import torch
import torch_scatter


rusty1s's avatar
rusty1s committed
5
def coalesce(index, value, m, n, op='add', fill_value=0):
    """Row-wise reorders and removes duplicate entries in a sparse matrix.

    Entries are sorted by their linearized position ``row * n + col``
    (row-major order). Entries that share the same ``(row, col)`` pair are
    merged into a single entry.

    Args:
        index: Pair ``(row, col)`` of integer index tensors, unpacked as
            ``row, col = index`` (e.g. a ``2 x nnz`` tensor).
        value: Values attached to each entry, or ``None``. When given,
            duplicate entries are reduced along dim 0 with
            ``torch_scatter.scatter_<op>``. Assumes ``value.size(0)`` equals
            the number of input entries — TODO confirm against callers.
        m: Number of rows. Not used by this implementation; kept for API
            compatibility.
        n: Number of columns, used to linearize ``(row, col)`` into one key.
        op: Suffix of the torch_scatter reduction to apply to duplicate
            values, e.g. ``'add'`` selects ``torch_scatter.scatter_add``.
        fill_value: Forwarded to the scatter reduction.

    Returns:
        Tuple ``(index, value)``:

        - ``index`` is the coalesced index, stacked along dim 0 as
          ``[row, col]``.
        - ``value`` is the reduced value tensor, or ``None`` if ``value``
          was ``None``.
    """
    row, col = index

    # Linearize (row, col) into a single key; sorted unique keys give
    # row-major order and `inv` maps each input entry to its output slot.
    unique, inv = torch.unique(row * n + col, sorted=True, return_inverse=True)

    # For every output slot, record the position of one input entry that
    # maps to it. Which duplicate "wins" the scatter is unspecified, but all
    # duplicates share the same (row, col), so the result is unaffected.
    perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device)
    perm = inv.new_empty(unique.size(0)).scatter_(0, inv, perm)
    index = torch.stack([row[perm], col[perm]], dim=0)

    if value is not None:
        scatter_fn = getattr(torch_scatter, 'scatter_{}'.format(op))
        value = scatter_fn(value, inv, 0, None, perm.size(0), fill_value)
        # Some reductions (e.g. max/min) return a (values, arg) tuple;
        # callers expect only the reduced values.
        value = value[0] if isinstance(value, tuple) else value

    return index, value