"""Utilities for training and evaluating a DGMG model on cycle graphs.

The module provides helpers to build teacher-forcing decision sequences,
a pickle-backed dataset of such sequences, an evaluator that samples graphs
from a model and reports how many of them are valid cycles, and a small
progress printer for the training loop.
"""

import os
import pickle
import random

import matplotlib.pyplot as plt
import networkx as nx
from torch.utils.data import Dataset


def get_previous(i, v_max):
    """Return the index before ``i`` on a ring of indices ``0..v_max``."""
    return v_max if i == 0 else i - 1


def get_next(i, v_max):
    """Return the index after ``i`` on a ring of indices ``0..v_max``."""
    return 0 if i == v_max else i + 1


def is_cycle(g):
    """Check whether ``g`` is a cycle whose nodes are linked in index order.

    Every node ``i`` must have exactly two successors, namely its ring
    predecessor and ring successor (``i - 1`` and ``i + 1`` with wrap-around).
    Graphs with fewer than three nodes are never considered cycles.

    ``g`` must expose ``number_of_nodes()`` and ``successors(node)``; the
    value returned by ``successors`` must support ``len`` and ``in``.
    """
    size = g.number_of_nodes()
    if size < 3:
        return False

    last = size - 1
    for node in range(size):
        neighbors = g.successors(node)
        if len(neighbors) != 2:
            return False
        if get_previous(node, last) not in neighbors:
            return False
        if get_next(node, last) not in neighbors:
            return False

    return True


def get_decision_sequence(size):
    """
    Get the decision sequence for generating valid cycles with DGMG for
    teacher forcing optimization.

    The returned flat list of ints interleaves "add node", "add edge",
    "destination" and "stop" decisions for a cycle with ``size`` nodes.
    """
    decision_sequence = []

    for i in range(size):
        decision_sequence.append(0)  # Add node

        if i != 0:
            decision_sequence.append(0)  # Add edge
            decision_sequence.append(i - 1)  # Set destination to be previous node.

        if i == size - 1:
            decision_sequence.append(0)  # Add edge
            decision_sequence.append(0)  # Set destination to be the root.

        decision_sequence.append(1)  # Stop adding edge

    decision_sequence.append(1)  # Stop adding node

    return decision_sequence


def generate_dataset(v_min, v_max, n_samples, fname):
    """Pickle ``n_samples`` decision sequences to the file ``fname``.

    Each sequence is built for a cycle size drawn uniformly from the
    inclusive range ``[v_min, v_max]``.
    """
    samples = [
        get_decision_sequence(random.randint(v_min, v_max))
        for _ in range(n_samples)
    ]

    with open(fname, 'wb') as f:
        pickle.dump(samples, f)


class CycleDataset(Dataset):
    """Dataset of decision sequences loaded from a pickle file."""

    def __init__(self, fname):
        super(CycleDataset, self).__init__()

        # NOTE: pickle.load can execute arbitrary code while unpickling;
        # only load dataset files that come from a trusted source.
        with open(fname, 'rb') as f:
            self.dataset = pickle.load(f)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        return self.dataset[index]

    def collate_single(self, batch):
        """Collate a batch of exactly one sample by unwrapping it."""
        assert len(batch) == 1, 'Currently we do not support batched training'
        return batch[0]

    def collate_batch(self, batch):
        """Collate a batch by returning the list of samples unchanged."""
        return batch


def dglGraph_to_adj_list(g):
    """Return ``{node: [successor, ...]}`` for every node of ``g``.

    ``g.successors(node)`` must return an object with a ``tolist()`` method.
    """
    # For undirected graph. successors and
    # predecessors are equivalent.
    return {
        node: g.successors(node).tolist()
        for node in range(g.number_of_nodes())
    }


class CycleModelEvaluation(object):
    """Sample graphs from a model and summarize how many are valid cycles.

    ``v_min`` and ``v_max`` bound the accepted graph size (inclusive) and
    ``dir`` is the directory that receives plots and the text summary.
    """

    def __init__(self, v_min, v_max, dir):
        super(CycleModelEvaluation, self).__init__()

        self.v_min = v_min
        self.v_max = v_max
        self.dir = dir

    def _plot_samples(self, adj_lists, plot_index):
        """Draw the given adjacency lists on a 2x2 grid and save the figure.

        The figure is written to ``<dir>/samples/<plot_index>``; the
        ``samples`` directory is created when it does not exist yet.
        """
        samples_dir = os.path.join(self.dir, 'samples')
        os.makedirs(samples_dir, exist_ok=True)

        fig, axes = plt.subplots(2, 2)
        for ax, adj_list in zip(axes.flat, adj_lists):
            nx.draw_circular(nx.from_dict_of_lists(adj_list),
                             with_labels=True, ax=ax)

        fig.savefig(os.path.join(samples_dir, '{:d}'.format(plot_index)))
        plt.close(fig)

    def rollout_and_examine(self, model, num_samples):
        """Sample ``num_samples`` graphs from ``model`` and record statistics.

        The model is called with no arguments once per sample. Every four
        sampled graphs are plotted together and saved under ``<dir>/samples``.
        Results are stored on the instance as ``num_samples_examined``,
        ``average_size``, ``valid_size_ratio``, ``cycle_ratio`` and
        ``valid_ratio``.

        Raises:
            ValueError: if ``num_samples`` is smaller than 1.
        """
        assert not model.training, 'You need to call model.eval().'
        if num_samples < 1:
            # The ratios below divide by num_samples.
            raise ValueError('num_samples must be a positive integer')

        num_total_size = 0
        num_valid_size = 0
        num_cycle = 0
        num_valid = 0
        plot_times = 0
        adj_lists_to_plot = []

        for _ in range(num_samples):
            sampled_graph = model()
            if isinstance(sampled_graph, list):
                # When the model is a batched implementation, a list of
                # DGLGraph objects is returned. Note that with model(),
                # we generate a single graph as with the non-batched
                # implementation. We actually support batched generation
                # during the inference so feel free to modify the code.
                sampled_graph = sampled_graph[0]

            adj_lists_to_plot.append(dglGraph_to_adj_list(sampled_graph))

            graph_size = sampled_graph.number_of_nodes()
            valid_size = (self.v_min <= graph_size <= self.v_max)
            cycle = is_cycle(sampled_graph)

            num_total_size += graph_size

            if valid_size:
                num_valid_size += 1

            if cycle:
                num_cycle += 1

            if valid_size and cycle:
                num_valid += 1

            if len(adj_lists_to_plot) >= 4:
                plot_times += 1
                self._plot_samples(adj_lists_to_plot, plot_times)
                adj_lists_to_plot = []

        self.num_samples_examined = num_samples
        self.average_size = num_total_size / num_samples
        self.valid_size_ratio = num_valid_size / num_samples
        self.cycle_ratio = num_cycle / num_samples
        self.valid_ratio = num_valid / num_samples

    def write_summary(self):
        """Write the recorded statistics to ``<dir>/model_eval.txt``.

        One tab-separated ``key<TAB>value`` pair is written per line. The
        statistics attributes are set by ``rollout_and_examine``.
        """

        def _format_value(v):
            if isinstance(v, float):
                return '{:.4f}'.format(v)
            elif isinstance(v, int):
                return '{:d}'.format(v)
            else:
                return '{}'.format(v)

        statistics = {
            'num_samples': self.num_samples_examined,
            'v_min': self.v_min,
            'v_max': self.v_max,
            'average_size': self.average_size,
            'valid_size_ratio': self.valid_size_ratio,
            'cycle_ratio': self.cycle_ratio,
            'valid_ratio': self.valid_ratio
        }

        model_eval_path = os.path.join(self.dir, 'model_eval.txt')

        with open(model_eval_path, 'w') as f:
            for key, value in statistics.items():
                msg = '{}\t{}\n'.format(key, _format_value(value))
                f.write(msg)

        print('Saved model evaluation statistics to {}'.format(model_eval_path))


class CyclePrinting(object):
    """Print one progress line per training batch."""

    def __init__(self, num_epochs, num_batches):
        super(CyclePrinting, self).__init__()

        self.num_epochs = num_epochs
        self.num_batches = num_batches
        self.batch_count = 0

    def update(self, epoch, metrics):
        """Advance the batch counter and print ``epoch`` and ``metrics``.

        The batch counter runs from 1 to ``num_batches`` and then wraps back
        to 1. ``metrics`` maps a metric name to a numeric value.
        """
        self.batch_count = (self.batch_count) % self.num_batches + 1

        msg = 'epoch {:d}/{:d}, batch {:d}/{:d}'.format(
            epoch, self.num_epochs, self.batch_count, self.num_batches)
        for key, value in metrics.items():
            # Four decimal places, matching the precision used by
            # CycleModelEvaluation.write_summary.
            msg += ', {}: {:.4f}'.format(key, value)
        print(msg)