""" This file re-uses implementation from https://github.com/yl-1993/learn-to-cluster """ import dgl import numpy as np import torch from sklearn import mixture from .density import density_to_peaks, density_to_peaks_vectorize __all__ = [ "peaks_to_labels", "edge_to_connected_graph", "decode", "build_next_level", ] def _find_parent(parent, u): idx = [] # parent is a fixed point while u != parent[u]: idx.append(u) u = parent[u] for i in idx: parent[i] = u return u def edge_to_connected_graph(edges, num): parent = list(range(num)) for u, v in edges: p_u = _find_parent(parent, u) p_v = _find_parent(parent, v) parent[p_u] = p_v for i in range(num): parent[i] = _find_parent(parent, i) remap = {} uf = np.unique(np.array(parent)) for i, f in enumerate(uf): remap[f] = i cluster_id = np.array([remap[f] for f in parent]) return cluster_id def peaks_to_edges(peaks, dist2peak, tau): edges = [] for src in peaks: dsts = peaks[src] dists = dist2peak[src] for dst, dist in zip(dsts, dists): if src == dst or dist >= 1 - tau: continue edges.append([src, dst]) return edges def peaks_to_labels(peaks, dist2peak, tau, inst_num): edges = peaks_to_edges(peaks, dist2peak, tau) pred_labels = edge_to_connected_graph(edges, inst_num) return pred_labels, edges def get_dists(g, nbrs, use_gt): k = nbrs.shape[1] src_id = nbrs[:, 1:].reshape(-1) dst_id = nbrs[:, 0].repeat(k - 1) eids = g.edge_ids(src_id, dst_id) if use_gt: new_dists = ( (1 - g.edata["labels_edge"][eids]).reshape(-1, k - 1).float() ) else: new_dists = g.edata["prob_conn"][eids, 0].reshape(-1, k - 1) ind = torch.argsort(new_dists, 1) offset = torch.LongTensor( (nbrs[:, 0] * (k - 1)).repeat(k - 1).reshape(-1, k - 1) ).to(g.device) ind = ind + offset nbrs = torch.LongTensor(nbrs).to(g.device) new_nbrs = torch.take(nbrs[:, 1:], ind) new_dists = torch.cat( [torch.zeros((new_dists.shape[0], 1)).to(g.device), new_dists], dim=1 ) new_nbrs = torch.cat( [torch.arange(new_nbrs.shape[0]).view(-1, 1).to(g.device), new_nbrs], dim=1, ) return new_nbrs.cpu().detach().numpy(), new_dists.cpu().detach().numpy() def get_edge_dist(g, threshold): if threshold == "prob": return g.edata["prob_conn"][:, 0] return 1 - g.edata["raw_affine"] def tree_generation(ng): ng.ndata["keep_eid"] = torch.zeros(ng.num_nodes()).long() - 1 def message_func(edges): return {"mval": edges.data["edge_dist"], "meid": edges.data[dgl.EID]} def reduce_func(nodes): ind = torch.min(nodes.mailbox["mval"], dim=1)[1] keep_eid = nodes.mailbox["meid"].gather(1, ind.view(-1, 1)) return {"keep_eid": keep_eid[:, 0]} node_order = dgl.traversal.topological_nodes_generator(ng) ng.prop_nodes(node_order, message_func, reduce_func) eids = ng.ndata["keep_eid"] eids = eids[eids > -1] edges = ng.find_edges(eids) treeg = dgl.graph(edges, num_nodes=ng.num_nodes()) return treeg def peak_propogation(treeg): treeg.ndata["pred_labels"] = torch.zeros(treeg.num_nodes()).long() - 1 peaks = torch.where(treeg.in_degrees() == 0)[0].cpu().numpy() treeg.ndata["pred_labels"][peaks] = torch.arange(peaks.shape[0]) def message_func(edges): return {"mlb": edges.src["pred_labels"]} def reduce_func(nodes): return {"pred_labels": nodes.mailbox["mlb"][:, 0]} node_order = dgl.traversal.topological_nodes_generator(treeg) treeg.prop_nodes(node_order, message_func, reduce_func) pred_labels = treeg.ndata["pred_labels"].cpu().numpy() return peaks, pred_labels def decode( g, tau, threshold, use_gt, ids=None, global_edges=None, global_num_nodes=None, global_peaks=None, ): # Edge filtering with tau and density den_key = "density" if use_gt else "pred_den" 
    g = g.local_var()
    g.edata["edge_dist"] = get_edge_dist(g, threshold)
    g.apply_edges(
        lambda edges: {
            "keep": (edges.src[den_key] > edges.dst[den_key]).long()
            * (edges.data["edge_dist"] < 1 - tau).long()
        }
    )
    eids = torch.where(g.edata["keep"] == 0)[0]
    ng = dgl.remove_edges(g, eids)

    # Tree generation
    ng.edata[dgl.EID] = torch.arange(ng.num_edges())
    treeg = tree_generation(ng)

    # Label propagation
    peaks, pred_labels = peak_propogation(treeg)
    if ids is None:
        return pred_labels, peaks

    # Merge with previous layers
    src, dst = treeg.edges()
    new_global_edges = (
        global_edges[0] + ids[src.numpy()].tolist(),
        global_edges[1] + ids[dst.numpy()].tolist(),
    )
    global_treeg = dgl.graph(new_global_edges, num_nodes=global_num_nodes)
    global_peaks, global_pred_labels = peak_propogation(global_treeg)
    return (
        pred_labels,
        peaks,
        new_global_edges,
        global_pred_labels,
        global_peaks,
    )


def build_next_level(
    features, labels, peaks, global_features, global_pred_labels, global_peaks
):
    """Build the next level's inputs: the current peaks become the new nodes,
    and each cluster is summarized by the mean feature of its members."""
    global_peak_to_label = global_pred_labels[global_peaks]
    global_label_to_peak = np.zeros_like(global_peak_to_label)
    # Map each cluster label back to the index of its peak so that
    # cluster_features stays aligned with the peak ordering.
    for i, pl in enumerate(global_peak_to_label):
        global_label_to_peak[pl] = i
    # Group node indices by cluster label.
    cluster_ind = np.split(
        np.argsort(global_pred_labels),
        np.unique(np.sort(global_pred_labels), return_index=True)[1][1:],
    )
    cluster_features = np.zeros((len(peaks), global_features.shape[1]))
    for pi in range(len(peaks)):
        cluster_features[global_label_to_peak[pi], :] = np.mean(
            global_features[cluster_ind[pi], :], axis=0
        )
    features = features[peaks]
    labels = labels[peaks]
    return features, labels, cluster_features
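

if __name__ == "__main__":
    # Minimal sanity-check sketch, not part of the upstream implementation:
    # edge_to_connected_graph is a union-find connected-components helper, so
    # nodes linked by edges share a cluster id; peaks_to_labels applies the
    # same grouping after filtering peak edges by tau. The toy edges below
    # are hypothetical.
    toy_edges = [(0, 1), (1, 2), (3, 4)]
    print(edge_to_connected_graph(toy_edges, 6))  # -> [0 0 0 1 1 2]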