""" Learning Deep Generative Models of Graphs Paper: https://arxiv.org/pdf/1803.03324.pdf This implementation works with a minibatch of size 1 only for both training and inference. """ import argparse import datetime import time import torch from model import DGMG from torch.nn.utils import clip_grad_norm_ from torch.optim import Adam from torch.utils.data import DataLoader def main(opts): t1 = time.time() # Setup dataset and data loader if opts["dataset"] == "cycles": from cycles import CycleDataset, CycleModelEvaluation, CyclePrinting dataset = CycleDataset(fname=opts["path_to_dataset"]) evaluator = CycleModelEvaluation( v_min=opts["min_size"], v_max=opts["max_size"], dir=opts["log_dir"] ) printer = CyclePrinting( num_epochs=opts["nepochs"], num_batches=opts["ds_size"] // opts["batch_size"], ) else: raise ValueError("Unsupported dataset: {}".format(opts["dataset"])) data_loader = DataLoader( dataset, batch_size=1, shuffle=True, num_workers=0, collate_fn=dataset.collate_single, ) # Initialize_model model = DGMG( v_max=opts["max_size"], node_hidden_size=opts["node_hidden_size"], num_prop_rounds=opts["num_propagation_rounds"], ) # Initialize optimizer if opts["optimizer"] == "Adam": optimizer = Adam(model.parameters(), lr=opts["lr"]) else: raise ValueError("Unsupported argument for the optimizer") t2 = time.time() # Training model.train() for epoch in range(opts["nepochs"]): batch_count = 0 batch_loss = 0 batch_prob = 0 optimizer.zero_grad() for i, data in enumerate(data_loader): log_prob = model(actions=data) prob = log_prob.detach().exp() loss = -log_prob / opts["batch_size"] prob_averaged = prob / opts["batch_size"] loss.backward() batch_loss += loss.item() batch_prob += prob_averaged.item() batch_count += 1 if batch_count % opts["batch_size"] == 0: printer.update( epoch + 1, {"averaged_loss": batch_loss, "averaged_prob": batch_prob}, ) if opts["clip_grad"]: clip_grad_norm_(model.parameters(), opts["clip_bound"]) optimizer.step() batch_loss = 0 batch_prob = 0 optimizer.zero_grad() t3 = time.time() model.eval() evaluator.rollout_and_examine(model, opts["num_generated_samples"]) evaluator.write_summary() t4 = time.time() print("It took {} to setup.".format(datetime.timedelta(seconds=t2 - t1))) print( "It took {} to finish training.".format( datetime.timedelta(seconds=t3 - t2) ) ) print( "It took {} to finish evaluation.".format( datetime.timedelta(seconds=t4 - t3) ) ) print( "--------------------------------------------------------------------------" ) print( "On average, an epoch takes {}.".format( datetime.timedelta(seconds=(t3 - t2) / opts["nepochs"]) ) ) del model.g torch.save(model, "./model.pth") if __name__ == "__main__": parser = argparse.ArgumentParser(description="DGMG") # configure parser.add_argument("--seed", type=int, default=9284, help="random seed") # dataset parser.add_argument( "--dataset", choices=["cycles"], default="cycles", help="dataset to use" ) parser.add_argument( "--path-to-dataset", type=str, default="cycles.p", help="load the dataset if it exists, " "generate it and save to the path otherwise", ) # log parser.add_argument( "--log-dir", default="./results", help="folder to save info like experiment configuration " "or model evaluation results", ) # optimization parser.add_argument( "--batch-size", type=int, default=10, help="batch size to use for training", ) parser.add_argument( "--clip-grad", action="store_true", default=True, help="gradient clipping is required to prevent gradient explosion", ) parser.add_argument( "--clip-bound", type=float, default=0.25, help="constraint of gradient norm for gradient clipping", ) args = parser.parse_args() from utils import setup opts = setup(args) main(opts)