import os
import pickle
import random

import matplotlib.pyplot as plt
import networkx as nx
from torch.utils.data import Dataset


def get_previous(i, v_max):
    # Index of the previous node on the cycle, wrapping from 0 back to v_max.
    if i == 0:
        return v_max
    else:
        return i - 1


def get_next(i, v_max):
    # Index of the next node on the cycle, wrapping from v_max back to 0.
    if i == v_max:
        return 0
    else:
        return i + 1


def is_cycle(g):
    """Check whether DGLGraph g is a cycle, i.e. every node is connected to
    exactly its previous and next node in the node ordering."""
    size = g.num_nodes()

    if size < 3:
        return False

    for node in range(size):
        neighbors = g.successors(node)

        if len(neighbors) != 2:
            return False

        if get_previous(node, size - 1) not in neighbors:
            return False

        if get_next(node, size - 1) not in neighbors:
            return False

    return True
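

# A quick sanity-check sketch, not part of the original example: it builds a
# 3-node cycle with dgl.graph (the explicit edge construction here is an
# assumption, since the rest of this file only consumes DGLGraph objects) and
# verifies it with is_cycle.
#
#   import dgl
#   import torch
#
#   src = torch.tensor([0, 1, 1, 2, 2, 0])
#   dst = torch.tensor([1, 0, 2, 1, 0, 2])
#   assert is_cycle(dgl.graph((src, dst)))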


def get_decision_sequence(size):
    """
    Get the decision sequence for generating valid cycles with DGMG for teacher
    forcing optimization.
    """
    decision_sequence = []

    for i in range(size):
        decision_sequence.append(0)  # Add node

        if i != 0:
            decision_sequence.append(0)  # Add edge
            decision_sequence.append(
                i - 1
            )  # Set destination to be previous node.

        if i == size - 1:
            decision_sequence.append(0)  # Add edge
            decision_sequence.append(0)  # Set destination to be the root.

        decision_sequence.append(1)  # Stop adding edge

    decision_sequence.append(1)  # Stop adding node

    return decision_sequence
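

# An illustrative trace, not part of the original example: unrolling the loop
# above for size=3 gives the decision sequence below, with comments naming the
# DGMG actions each entry encodes.
#
#   get_decision_sequence(3)
#   == [0, 1,           # add node 0, stop adding edges
#       0, 0, 0, 1,     # add node 1, add edge, destination node 0, stop
#       0, 0, 1,        # add node 2, add edge, destination node 1
#       0, 0, 1,        # add edge, destination node 0 (close the cycle), stop
#       1]              # stop adding nodes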


def generate_dataset(v_min, v_max, n_samples, fname):
    samples = []
    for _ in range(n_samples):
        size = random.randint(v_min, v_max)
        samples.append(get_decision_sequence(size))

    with open(fname, "wb") as f:
        pickle.dump(samples, f)


class CycleDataset(Dataset):
    def __init__(self, fname):
        super(CycleDataset, self).__init__()

        with open(fname, "rb") as f:
            self.dataset = pickle.load(f)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        return self.dataset[index]

    def collate_single(self, batch):
        assert len(batch) == 1, "Currently we do not support batched training"
        return batch[0]

    def collate_batch(self, batch):
        return batch
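

# A minimal usage sketch with a hypothetical file name and size range (the
# original training script may wire this up differently); it assumes the
# standard torch.utils.data.DataLoader.
#
#   from torch.utils.data import DataLoader
#
#   generate_dataset(v_min=10, v_max=20, n_samples=4000, fname="cycles.p")
#   dataset = CycleDataset("cycles.p")
#   loader = DataLoader(dataset, batch_size=1, shuffle=True,
#                       collate_fn=dataset.collate_single)
#   for decision_sequence in loader:
#       ...  # one teacher-forcing step with the DGMG model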


def dglGraph_to_adj_list(g):
    adj_list = {}
    for node in range(g.num_nodes()):
        # For an undirected graph, successors and
        # predecessors are equivalent.
        adj_list[node] = g.successors(node).tolist()
    return adj_list
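

# For illustration (hypothetical output; the neighbor order depends on the
# edge insertion order): a 3-node cycle with edges stored in both directions
# would yield an adjacency list like
#
#   {0: [1, 2], 1: [0, 2], 2: [1, 0]}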


class CycleModelEvaluation(object):
    def __init__(self, v_min, v_max, dir):
        super(CycleModelEvaluation, self).__init__()

        self.v_min = v_min
        self.v_max = v_max

        self.dir = dir

    def rollout_and_examine(self, model, num_samples):
        assert not model.training, "You need to call model.eval()."

        num_total_size = 0
        num_valid_size = 0
        num_cycle = 0
        num_valid = 0
        plot_times = 0
        adj_lists_to_plot = []

        for i in range(num_samples):
            sampled_graph = model()
            if isinstance(sampled_graph, list):
                # The batched implementation of the model returns a list of
                # DGLGraph objects. With model() we still generate a single
                # graph, as in the non-batched implementation, but batched
                # generation is supported at inference time, so feel free to
                # modify the code.
                sampled_graph = sampled_graph[0]

            sampled_adj_list = dglGraph_to_adj_list(sampled_graph)
            adj_lists_to_plot.append(sampled_adj_list)

            graph_size = sampled_graph.num_nodes()
            valid_size = self.v_min <= graph_size <= self.v_max
            cycle = is_cycle(sampled_graph)

            num_total_size += graph_size

            if valid_size:
                num_valid_size += 1

            if cycle:
                num_cycle += 1

            if valid_size and cycle:
                num_valid += 1

            if len(adj_lists_to_plot) >= 4:
                plot_times += 1
                fig, ((ax0, ax1), (ax2, ax3)) = plt.subplots(2, 2)
                axes = {0: ax0, 1: ax1, 2: ax2, 3: ax3}
                # Use a separate loop variable so the sample index i is not
                # shadowed.
                for j in range(4):
                    nx.draw_circular(
                        nx.from_dict_of_lists(adj_lists_to_plot[j]),
                        with_labels=True,
                        ax=axes[j],
                    )

                plt.savefig(
                    os.path.join(self.dir, "samples", "{:d}".format(plot_times))
                )
                plt.close()

                adj_lists_to_plot = []

        self.num_samples_examined = num_samples
        self.average_size = num_total_size / num_samples
        self.valid_size_ratio = num_valid_size / num_samples
        self.cycle_ratio = num_cycle / num_samples
        self.valid_ratio = num_valid / num_samples

    def write_summary(self):
        def _format_value(v):
            if isinstance(v, float):
                return "{:.4f}".format(v)
            elif isinstance(v, int):
                return "{:d}".format(v)
            else:
                return "{}".format(v)

        statistics = {
            "num_samples": self.num_samples_examined,
            "v_min": self.v_min,
            "v_max": self.v_max,
            "average_size": self.average_size,
            "valid_size_ratio": self.valid_size_ratio,
            "cycle_ratio": self.cycle_ratio,
            "valid_ratio": self.valid_ratio,
        }

        model_eval_path = os.path.join(self.dir, "model_eval.txt")

        with open(model_eval_path, "w") as f:
            for key, value in statistics.items():
                msg = "{}\t{}\n".format(key, _format_value(value))
                f.write(msg)

        print("Saved model evaluation statistics to {}".format(model_eval_path))


class CyclePrinting(object):
    def __init__(self, num_epochs, num_batches):
        super(CyclePrinting, self).__init__()

        self.num_epochs = num_epochs
        self.num_batches = num_batches
        self.batch_count = 0

    def update(self, epoch, metrics):
        self.batch_count = self.batch_count % self.num_batches + 1

        msg = "epoch {:d}/{:d}, batch {:d}/{:d}".format(
            epoch, self.num_epochs, self.batch_count, self.num_batches
        )
        for key, value in metrics.items():
            msg += ", {}: {:.4f}".format(key, value)
        print(msg)