"""
This file re-uses implementation from https://github.com/yl-1993/learn-to-cluster
"""
import dgl
import numpy as np
import torch
from sklearn import mixture

from .density import density_to_peaks, density_to_peaks_vectorize

__all__ = [
    "peaks_to_labels",
    "edge_to_connected_graph",
    "decode",
    "build_next_level",
]


def _find_parent(parent, u):
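    """Find the root of u in the union-find array `parent`, with path compression."""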
    idx = []
    # a root is its own parent, so stop when parent[u] == u
    while u != parent[u]:
        idx.append(u)
        u = parent[u]
    for i in idx:
        parent[i] = u
    return u


def edge_to_connected_graph(edges, num):
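    """Run union-find over the edge list and return a contiguous cluster id for each of the `num` nodes."""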
    parent = list(range(num))
    for u, v in edges:
        p_u = _find_parent(parent, u)
        p_v = _find_parent(parent, v)
        parent[p_u] = p_v

    for i in range(num):
        parent[i] = _find_parent(parent, i)
    remap = {}
    uf = np.unique(np.array(parent))
    for i, f in enumerate(uf):
        remap[f] = i
    cluster_id = np.array([remap[f] for f in parent])
    return cluster_id


def peaks_to_edges(peaks, dist2peak, tau):
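    """Build edges from each node to its peaks whose distance is below 1 - tau (self-loops skipped)."""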
    edges = []
    for src in peaks:
        dsts = peaks[src]
        dists = dist2peak[src]
        for dst, dist in zip(dsts, dists):
            if src == dst or dist >= 1 - tau:
                continue
            edges.append([src, dst])
    return edges


def peaks_to_labels(peaks, dist2peak, tau, inst_num):
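    """Convert per-node peaks into cluster labels over `inst_num` instances; returns (pred_labels, edges)."""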
    edges = peaks_to_edges(peaks, dist2peak, tau)
    pred_labels = edge_to_connected_graph(edges, inst_num)
    return pred_labels, edges


def get_dists(g, nbrs, use_gt):
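    """Look up each node's k-1 neighbor distances on the graph (ground-truth
    edge labels if `use_gt`, otherwise the predicted disconnection probability),
    sort the neighbors by distance, and prepend the node itself with distance 0.
    Returns (new_nbrs, new_dists) as numpy arrays.
    """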
    k = nbrs.shape[1]
    src_id = nbrs[:, 1:].reshape(-1)
    dst_id = nbrs[:, 0].repeat(k - 1)
    eids = g.edge_ids(src_id, dst_id)
    if use_gt:
        new_dists = (
            (1 - g.edata["labels_edge"][eids]).reshape(-1, k - 1).float()
        )
    else:
        new_dists = g.edata["prob_conn"][eids, 0].reshape(-1, k - 1)
    # sort each node's neighbors by distance so new_dists stays aligned with new_nbrs
    new_dists, ind = torch.sort(new_dists, dim=1)
    offset = torch.LongTensor(
        (nbrs[:, 0] * (k - 1)).repeat(k - 1).reshape(-1, k - 1)
    ).to(g.device)
    ind = ind + offset
    nbrs = torch.LongTensor(nbrs).to(g.device)
    new_nbrs = torch.take(nbrs[:, 1:], ind)
    new_dists = torch.cat(
        [torch.zeros((new_dists.shape[0], 1)).to(g.device), new_dists], dim=1
    )
    new_nbrs = torch.cat(
        [torch.arange(new_nbrs.shape[0]).view(-1, 1).to(g.device), new_nbrs],
        dim=1,
    )
    return new_nbrs.cpu().detach().numpy(), new_dists.cpu().detach().numpy()


def get_edge_dist(g, threshold):
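    """Per-edge distance: the predicted disconnection probability when threshold
    is "prob", otherwise 1 minus the raw affinity.
    """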
    if threshold == "prob":
        return g.edata["prob_conn"][:, 0]
    return 1 - g.edata["raw_affine"]


def tree_generation(ng):
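    """Keep, for every node, only the incoming edge with the smallest edge
    distance, yielding a forest whose roots are the density peaks.
    """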
    ng.ndata["keep_eid"] = torch.zeros(ng.num_nodes()).long() - 1

    def message_func(edges):
        return {"mval": edges.data["edge_dist"], "meid": edges.data[dgl.EID]}

    def reduce_func(nodes):
        ind = torch.min(nodes.mailbox["mval"], dim=1)[1]
        keep_eid = nodes.mailbox["meid"].gather(1, ind.view(-1, 1))
        return {"keep_eid": keep_eid[:, 0]}

    node_order = dgl.traversal.topological_nodes_generator(ng)
    ng.prop_nodes(node_order, message_func, reduce_func)
    eids = ng.ndata["keep_eid"]
    eids = eids[eids > -1]
    edges = ng.find_edges(eids)
    treeg = dgl.graph(edges, num_nodes=ng.num_nodes())
    return treeg


def peak_propogation(treeg):
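    """Assign a distinct label to every root (in-degree-0 peak) and propagate
    it down the tree; returns (peaks, pred_labels).
    """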
    treeg.ndata["pred_labels"] = torch.zeros(treeg.num_nodes()).long() - 1
    peaks = torch.where(treeg.in_degrees() == 0)[0].cpu().numpy()
    treeg.ndata["pred_labels"][peaks] = torch.arange(peaks.shape[0])

    def message_func(edges):
        return {"mlb": edges.src["pred_labels"]}

    def reduce_func(nodes):
        return {"pred_labels": nodes.mailbox["mlb"][:, 0]}

    node_order = dgl.traversal.topological_nodes_generator(treeg)
    treeg.prop_nodes(node_order, message_func, reduce_func)
    pred_labels = treeg.ndata["pred_labels"].cpu().numpy()
    return peaks, pred_labels


def decode(
    g,
    tau,
    threshold,
    use_gt,
    ids=None,
    global_edges=None,
    global_num_nodes=None,
    global_peaks=None,
):
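    """Decode cluster assignments from graph g: drop edges by density and the
    connectivity threshold tau, build a forest of peak-rooted trees, and
    propagate peak labels. If `ids` is given, the current level is also merged
    into the global graph accumulated over previous levels.
    """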
    # Edge filtering with tau and density
    den_key = "density" if use_gt else "pred_den"
    g = g.local_var()
    g.edata["edge_dist"] = get_edge_dist(g, threshold)
    g.apply_edges(
        lambda edges: {
            "keep": (edges.src[den_key] > edges.dst[den_key]).long()
            * (edges.data["edge_dist"] < 1 - tau).long()
        }
    )
    eids = torch.where(g.edata["keep"] == 0)[0]
    ng = dgl.remove_edges(g, eids)

    # Tree generation
    ng.edata[dgl.EID] = torch.arange(ng.num_edges())
    treeg = tree_generation(ng)
    # Label propagation
    peaks, pred_labels = peak_propogation(treeg)

    if ids is None:
        return pred_labels, peaks

    # Merge with previous layers
    src, dst = treeg.edges()
    new_global_edges = (
        global_edges[0] + ids[src.numpy()].tolist(),
        global_edges[1] + ids[dst.numpy()].tolist(),
    )
    global_treeg = dgl.graph(new_global_edges, num_nodes=global_num_nodes)
    global_peaks, global_pred_labels = peak_propogation(global_treeg)
    return (
        pred_labels,
        peaks,
        new_global_edges,
        global_pred_labels,
        global_peaks,
    )


def build_next_level(
    features, labels, peaks, global_features, global_pred_labels, global_peaks
):
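    """Prepare the next level: keep only peak features/labels and compute each
    cluster's mean feature, placed at the row of the peak that represents it.
    """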
    global_peak_to_label = global_pred_labels[global_peaks]
    global_label_to_peak = np.zeros_like(global_peak_to_label)
    for i, pl in enumerate(global_peak_to_label):
        global_label_to_peak[pl] = i
    cluster_ind = np.split(
        np.argsort(global_pred_labels),
        np.unique(np.sort(global_pred_labels), return_index=True)[1][1:],
    )
    cluster_features = np.zeros((len(peaks), global_features.shape[1]))
    for pi in range(len(peaks)):
        cluster_features[global_label_to_peak[pi], :] = np.mean(
            global_features[cluster_ind[pi], :], axis=0
        )
    features = features[peaks]
    labels = labels[peaks]
    return features, labels, cluster_features