dataset.py 3.81 KB
Newer Older
1
2
import pickle

3
import numpy as np
4
import torch
5
6
7
8
9
10
11
12
13
14
15
16
17
from utils import (
    build_knns,
    build_next_level,
    decode,
    density_estimation,
    fast_knns2spmat,
    knns2ordered_nbrs,
    l2norm,
    row_normalize,
    sparse_mx_to_indices_values,
)

import dgl
18
19
20


class LanderDataset(object):
    """Hierarchy of k-NN graphs built level by level over a feature set.

    Level 0 is a k-NN graph over the L2-normalised input features. Each
    further level is built from the output of ``build_next_level`` applied to
    the peak nodes that ``decode`` selects on the previous level's graph.
    Indexing the dataset returns the DGL graph for that level; ``len()`` is
    the number of levels actually built.
    """

    def __init__(
        self,
        features,
        labels,
        cluster_features=None,
        k=10,
        levels=1,
        faiss_gpu=False,
    ):
        """Build up to ``levels`` graph levels.

        Args:
            features: per-node features. They are cast to float32 and
                normalised with ``l2norm``. Presumably a numpy array of shape
                (N, D); it must support ``.astype``, ``.copy`` and ``.shape``.
            labels: per-node labels. They are passed to
                ``density_estimation`` and stored as ``ndata["labels"]``.
            cluster_features: features stored as
                ``ndata["cluster_features"]`` on the level-0 graph. Defaults to
                the normalised ``features``.
            k: neighbour count used for every k-NN graph.
            levels: maximum number of levels. If a level has at most ``k``
                nodes, construction stops and ``self.levels`` is lowered to
                the number of levels actually built.
            faiss_gpu: use the ``"faiss_gpu"`` k-NN backend instead of
                ``"faiss"``.
        """
        self.k = k
        self.gs = []
        self.nbrs = []
        self.dists = []
        self.levels = levels

        # Initialize features and labels
        features = l2norm(features.astype("float32"))
        global_features = features.copy()
        if cluster_features is None:
            cluster_features = features
        global_num_nodes = features.shape[0]
        global_edges = ([], [])
        # np.long was removed in NumPy 1.24 and re-added in 2.0 as the
        # platform-dependent C long; use an explicit 64-bit integer dtype.
        global_peaks = np.array([], dtype=np.int64)
        # ids[i] is the level-0 node index of node i at the current level.
        ids = np.arange(global_num_nodes)
        knn_method = "faiss_gpu" if faiss_gpu else "faiss"

        # Recursive graph construction
        for lvl in range(self.levels):
            # Too few nodes to build a k-NN graph: stop and record how many
            # levels were actually built.
            if features.shape[0] <= self.k:
                self.levels = lvl
                break

            knns = build_knns(features, self.k, knn_method)
            dists, nbrs = knns2ordered_nbrs(knns)
            self.nbrs.append(nbrs)
            self.dists.append(dists)
            density = density_estimation(dists, nbrs, labels)

            g = self._build_graph(
                features, cluster_features, labels, density, knns
            )
            self.gs.append(g)

            # The last requested level needs no decoding for a next level.
            if lvl >= self.levels - 1:
                break

            # Decode peak nodes
            (
                _new_pred_labels,
                peaks,
                global_edges,
                global_pred_labels,
                global_peaks,
            ) = decode(
                g,
                0,
                "sim",
                True,
                ids,
                global_edges,
                global_num_nodes,
                global_peaks,
            )
            ids = ids[peaks]
            features, labels, cluster_features = build_next_level(
                features,
                labels,
                peaks,
                global_features,
                global_pred_labels,
                global_peaks,
            )

    def _build_graph(self, features, cluster_features, labels, density, knns):
        """Create the DGL graph for one level from its k-NN result.

        The graph has one edge per non-zero entry of the row-normalised
        sparse k-NN matrix. Source is ``indices[1]`` and destination is
        ``indices[0]``; presumably that means column -> row, depending on
        ``sparse_mx_to_indices_values``.

        Node data: ``features``, ``cluster_features``, ``labels``,
        ``density``, ``norm``.

        Edge data:
            ``affine``: the normalised matrix value.
            ``global_eid``: the edge id.
            ``raw_affine``: ``affine`` divided by the destination node's
                ``norm``.
            ``labels_conn``: 1 if the endpoint labels match, else 0.
            ``mask_conn``: True if source density is greater than
                destination density.
        """
        adj = fast_knns2spmat(knns, self.k)
        adj, adj_row_sum = row_normalize(adj)
        indices, values, _ = sparse_mx_to_indices_values(adj)

        g = dgl.graph((indices[1], indices[0]))
        g.ndata["features"] = torch.FloatTensor(features)
        g.ndata["cluster_features"] = torch.FloatTensor(cluster_features)
        g.ndata["labels"] = torch.LongTensor(labels)
        g.ndata["density"] = torch.FloatTensor(density)
        g.edata["affine"] = torch.FloatTensor(values)
        # A Bipartite from DGL sampler will not store global eid, so we explicitly save it here
        g.edata["global_eid"] = g.edges(form="eid")
        g.ndata["norm"] = torch.FloatTensor(adj_row_sum)

        def _edge_features(edges):
            # One pass computes all derived per-edge features.
            return {
                "raw_affine": edges.data["affine"] / edges.dst["norm"],
                "labels_conn": (
                    edges.src["labels"] == edges.dst["labels"]
                ).long(),
                # A tensor comparison already yields a bool tensor.
                "mask_conn": edges.src["density"] > edges.dst["density"],
            }

        g.apply_edges(_edge_features)
        return g

    def __getitem__(self, index):
        """Return the graph built for level ``index``.

        Out-of-range indices raise IndexError, not AssertionError. That lets
        ``for g in dataset`` and ``list(dataset)`` terminate normally, and the
        check survives ``python -O``.
        """
        return self.gs[index]

    def __len__(self):
        """Return the number of graph levels that were built."""
        return len(self.gs)