Commit b8fb93bd authored by rusty1s's avatar rusty1s
Browse files

clean up

parent c3155ab6
...@@ -6,16 +6,14 @@ def coalesce(index, value, size, op='add', fill_value=0): ...@@ -6,16 +6,14 @@ def coalesce(index, value, size, op='add', fill_value=0):
m, n = size m, n = size
row, col = index row, col = index
index = row * n + col unique, inv = torch.unique(row * n + col, sorted=True, return_inverse=True)
unique, inv = torch.unique(index, sorted=True, return_inverse=True)
perm = torch.arange(index.size(0), dtype=index.dtype, device=index.device) perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device)
perm = index.new_empty(inv.max().item() + 1).scatter_(0, inv, perm) perm = inv.new_empty(inv.max().item() + 1).scatter_(0, inv, perm)
index = torch.stack([row[perm], col[perm]], dim=0) index = torch.stack([row[perm], col[perm]], dim=0)
if value is not None: if value is not None:
scatter = getattr(torch_scatter, 'scatter_{}'.format(op)) op = getattr(torch_scatter, 'scatter_{}'.format(op))
value = scatter( value = op(value, inv, 0, None, perm.size(0), fill_value)
value, inv, dim=0, dim_size=perm.size(0), fill_value=fill_value)
return index, value return index, value
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment