density.py 2.98 KB
Newer Older
1
2
3
4
5
6
7
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This file re-uses implementation from https://github.com/yl-1993/learn-to-cluster
"""

from itertools import groupby
8
9

import numpy as np
10
import torch
11
12
13
14
15
16
17
from tqdm import tqdm

__all__ = [
    "density_estimation",
    "density_to_peaks",
    "density_to_peaks_vectorize",
]
18
19
20


def density_estimation(dists, nbrs, labels, **kwargs):
    """Estimate a supervised density for every node from its kNN neighborhood.

    The first neighbor column is skipped (it is assumed to be the node
    itself -- confirm against the caller's kNN layout). Every remaining
    neighbor contributes its similarity ``1 - dist``: positively when its
    label equals the node's label, negatively otherwise. The signed sum is
    then averaged over the ``k - 1`` remaining neighbors.

    Args:
        dists: ``(num, k)`` NumPy array of distances to each neighbor.
        nbrs: ``(num, k)`` integer NumPy array of neighbor indices, used to
            index ``labels``.
        labels: ``(num,)`` NumPy array holding one label per node.
        **kwargs: Accepted for call-site compatibility; ignored.

    Returns:
        ``(num,)`` floating-point array (at least ``float32``) with one
        density value per node. With ``k == 1`` there are no neighbors to
        average over and the division by ``k - 1`` yields NaN.
    """
    num, k_knn = dists.shape

    # Broadcasting the (num, 1) label column against (num, k) neighbor labels
    # marks which neighbors share the node's own label.
    same_label = labels[nbrs] == np.expand_dims(labels, 1)

    similarity = 1 - dists[:, 1:]
    agree = same_label[:, 1:]
    pos = (similarity * agree).sum(1)
    neg = (similarity * ~agree).sum(1)

    # Multiplying by a float32 vector of ones keeps the result at least
    # float32 even for lower-precision distance inputs.
    conf = (pos - neg) * np.ones((num,), dtype=np.float32)
    conf /= k_knn - 1
    return conf

31
32

def density_to_peaks_vectorize(dists, nbrs, density, max_conn=1, name=""):
    """Link every node to its first neighbor of strictly higher density.

    Vectorized variant that supports a single connection per node only.

    Args:
        dists: ``(V, k)`` NumPy array of neighbor distances.
        nbrs: ``(V, k)`` integer NumPy array of neighbor indices; each row is
            expected to contain the row's own index exactly once.
        density: ``(V,)`` per-node density, either a NumPy array or a torch
            tensor.
        max_conn: Unused; only one connection per node is produced.
        name: When equal to ``"gcn_feat"``, a row that does not contain its
            own index has its last neighbor dropped instead, so the row
            still ends up with ``k - 1`` candidates.

    Returns:
        ``(dist2peak, peaks)``: two dicts keyed by node index. Each value is
        an empty list when no neighbor has a higher density, otherwise a
        one-element list with the distance to / index of the first such
        neighbor in row order.

    Raises:
        ValueError: If, after self-removal, some row does not keep exactly
            ``k - 1`` neighbors (previously this surfaced as a reshape error
            or could silently misalign rows).
    """
    # just calculate 1 connectivity
    assert dists.shape[0] == density.shape[0]
    assert dists.shape == nbrs.shape

    num, k = dists.shape

    # True wherever the neighbor is not the node itself.
    include_mask = nbrs != np.arange(0, num).reshape(-1, 1)

    if name == "gcn_feat":
        # TODO: the condition == k should not happen as the distance to the
        # node itself should be the smallest; check for numerical stability.
        # TODO: make top M instead of only supporting top 1.
        missing_self = np.sum(include_mask, axis=1) == k
        include_mask[missing_self, -1] = False

    # Boolean indexing flattens row-major, so the reshape below is aligned
    # only if every row keeps exactly k - 1 entries.
    if np.any(np.sum(include_mask, axis=1) != k - 1):
        raise ValueError(
            "every row of nbrs must keep exactly k - 1 neighbors after "
            "removing the node itself"
        )

    nbrs_exclude_self = nbrs[include_mask].reshape(num, k - 1)  # (V, k - 1)
    dists_exclude_self = dists[include_mask].reshape(num, k - 1)  # (V, k - 1)

    compare_map = density[nbrs_exclude_self] > density.reshape(-1, 1)
    # Convert once to a NumPy boolean array so both torch tensors (on any
    # device) and plain NumPy densities are handled by the code below.
    if hasattr(compare_map, "detach"):
        compare_map = compare_map.detach().cpu().numpy()
    else:
        compare_map = np.asarray(compare_map)

    has_peak = compare_map.any(axis=1)  # (V,)
    if k > 1:
        # argmax over booleans returns the position of the first True.
        peak_index = np.argmax(compare_map, axis=1)  # (V,)
    else:
        # No neighbors besides the node itself: nothing to select.
        peak_index = np.zeros(num, dtype=np.intp)

    dist2peak = {}
    peaks = {}
    for i in range(num):
        if has_peak[i]:
            j = peak_index[i]
            dist2peak[i] = [dists_exclude_self[i, j]]
            peaks[i] = [nbrs_exclude_self[i, j]]
        else:
            dist2peak[i] = []
            peaks[i] = []

    return dist2peak, peaks

71
72

def density_to_peaks(dists, nbrs, density, max_conn=1, sort="dist"):
    """Link every node to neighbors whose density is strictly higher.

    Neighbors are scanned in row order and the scan for a node stops once
    ``max_conn`` links have been collected. Because the limit is checked
    after each append, a node that has any higher-density neighbor always
    gets at least one link.

    Args:
        dists: ``(num, k)`` array of neighbor distances. Note that dists is
            expected to be sorted in ascending order along each row.
        nbrs: ``(num, k)`` array of neighbor indices, aligned with ``dists``.
        density: ``(num,)`` per-node density values, indexable by a row of
            ``nbrs``.
        max_conn: Number of links after which the scan for a node stops.
        sort: Unused by this implementation.

    Returns:
        ``(dist2peak, peaks)``: two dicts keyed by node index, holding the
        distances to and the indices of the selected higher-density
        neighbors (empty lists when there are none).
    """
    # Note that dists has been sorted in ascending order
    assert dists.shape[0] == density.shape[0]
    assert dists.shape == nbrs.shape

    num, _ = dists.shape
    dist2peak = {i: [] for i in range(num)}
    peaks = {i: [] for i in range(num)}

    # enumerate() has no length, so pass the total for a meaningful bar.
    for i, nbr in tqdm(enumerate(nbrs), total=num):
        own_density = density[i]
        node_dists = dist2peak[i]
        node_peaks = peaks[i]
        for j, nbr_density in enumerate(density[nbr]):
            nbr_idx = nbr[j]
            # Skip the node itself and neighbors that are not denser.
            if i == nbr_idx or nbr_density <= own_density:
                continue
            node_dists.append(dists[i, j])
            node_peaks.append(nbr_idx)
            if len(node_dists) >= max_conn:
                break

    return dist2peak, peaks