# dataset.py
import numpy as np

import dgl
import torch

from utils import (build_knns, fast_knns2spmat, row_normalize, knns2ordered_nbrs,
                   density_estimation, sparse_mx_to_indices_values, l2norm,
                   decode, build_next_level)

class LanderDataset(object):
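    """A hierarchy of k-NN graphs for LANDER-style clustering.

    Level 0 is a k-NN graph over all input features; each subsequent level is
    built over the peak nodes decoded from the previous level, so
    ``self.gs[i]`` holds the DGLGraph for level ``i``.
    """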
    def __init__(self, features, labels, cluster_features=None, k=10, levels=1, faiss_gpu=False):
        self.k = k
        self.gs = []
        self.nbrs = []
        self.dists = []
        self.levels = levels

        # Initialize features and labels
        features = l2norm(features.astype('float32'))
        global_features = features.copy()
        if cluster_features is None:
            cluster_features = features
        global_num_nodes = features.shape[0]
        global_edges = ([], [])
        global_peaks = np.array([], dtype=np.int64)
        ids = np.arange(global_num_nodes)  # global id of each current-level node

        # Build the hierarchy level by level; each level's nodes are the previous level's peaks
        for lvl in range(self.levels):
            # Too few nodes left to build a k-NN graph; truncate the hierarchy
            if features.shape[0] <= self.k:
                self.levels = lvl
                break
            knn_method = 'faiss_gpu' if faiss_gpu else 'faiss'
            knns = build_knns(features, self.k, knn_method)
            dists, nbrs = knns2ordered_nbrs(knns)
            self.nbrs.append(nbrs)
            self.dists.append(dists)
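            # Estimate a per-node density from the k-NN distances and neighbor labels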
            density = density_estimation(dists, nbrs, labels)

            g = self._build_graph(features, cluster_features, labels, density, knns)
            self.gs.append(g)

            # The final level keeps its graph without decoding further peaks
            if lvl >= self.levels - 1:
                break

            # Decode peak nodes and carry their global ids to the next level
            (new_pred_labels, peaks,
             global_edges, global_pred_labels, global_peaks) = decode(
                g, 0, 'sim', True, ids, global_edges, global_num_nodes,
                global_peaks)
            ids = ids[peaks]
            features, labels, cluster_features = build_next_level(
                features, labels, peaks,
                global_features, global_pred_labels, global_peaks)

    def _build_graph(self, features, cluster_features, labels, density, knns):
        adj = fast_knns2spmat(knns, self.k)
        adj, adj_row_sum = row_normalize(adj)
        indices, values, shape = sparse_mx_to_indices_values(adj)

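        # adjacency row i lists i's neighbors, so edges run neighbor -> node
        # (src = indices[1], dst = indices[0]); messages flow into each node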
        g = dgl.graph((indices[1], indices[0]))
        g.ndata['features'] = torch.FloatTensor(features)
        g.ndata['cluster_features'] = torch.FloatTensor(cluster_features)
        g.ndata['labels'] = torch.LongTensor(labels)
        g.ndata['density'] = torch.FloatTensor(density)
        g.edata['affine'] = torch.FloatTensor(values)
        # A bipartite block from a DGL sampler will not store the global edge ids, so we save them here explicitly
        g.edata['global_eid'] = g.edges(form='eid')
        g.ndata['norm'] = torch.FloatTensor(adj_row_sum)
        # Affinity rescaled by the destination node's row sum
        g.apply_edges(lambda edges: {'raw_affine': edges.data['affine'] / edges.dst['norm']})
        # Ground-truth connectivity: 1 if the two endpoints share a label
        g.apply_edges(lambda edges: {'labels_conn': (edges.src['labels'] == edges.dst['labels']).long()})
        # Mask for edges whose source node is denser than its destination
        g.apply_edges(lambda edges: {'mask_conn': (edges.src['density'] > edges.dst['density']).bool()})
        return g

    def __getitem__(self, index):
        assert index < len(self.gs)
        return self.gs[index]

    def __len__(self):
        return len(self.gs)
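
# Minimal usage sketch (illustrative only): the random features and labels
# below are hypothetical stand-ins for real embeddings, and `build_knns` is
# assumed to have a working faiss backend available.
if __name__ == '__main__':
    feats = np.random.rand(1000, 256).astype('float32')  # hypothetical embeddings
    lbls = np.random.randint(0, 50, size=1000)           # hypothetical labels
    dataset = LanderDataset(feats, lbls, k=10, levels=2)
    for i in range(len(dataset)):
        g = dataset[i]
        print(f'level {i}: {g.num_nodes()} nodes, {g.num_edges()} edges')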