import argparse import time import traceback import networkx as nx import numpy as np import torch from dataloader import ( MultiBodyGraphCollator, MultiBodyTestDataset, MultiBodyTrainDataset, MultiBodyValidDataset, ) from models import MLP, InteractionNet, PrepareLayer from torch.utils.data import DataLoader from utils import make_video import dgl def train( optimizer, loss_fn, reg_fn, model, prep, dataloader, lambda_reg, device ): total_loss = 0 model.train() for i, (graph_batch, data_batch, label_batch) in enumerate(dataloader): graph_batch = graph_batch.to(device) data_batch = data_batch.to(device) label_batch = label_batch.to(device) optimizer.zero_grad() node_feat, edge_feat = prep(graph_batch, data_batch) dummy_relation = torch.zeros(edge_feat.shape[0], 1).float().to(device) dummy_global = torch.zeros(node_feat.shape[0], 1).float().to(device) v_pred, out_e = model( graph_batch, node_feat[:, 3:5].float(), edge_feat.float(), dummy_global, dummy_relation, ) loss = loss_fn(v_pred, label_batch) total_loss += float(loss) zero_target = torch.zeros_like(out_e) loss = loss + lambda_reg * reg_fn(out_e, zero_target) reg_loss = 0 for param in model.parameters(): reg_loss = reg_loss + lambda_reg * reg_fn( param, torch.zeros_like(param).float().to(device) ) loss = loss + reg_loss loss.backward() optimizer.step() return total_loss / (i + 1) # One step evaluation def eval(loss_fn, model, prep, dataloader, device): total_loss = 0 model.eval() for i, (graph_batch, data_batch, label_batch) in enumerate(dataloader): graph_batch = graph_batch.to(device) data_batch = data_batch.to(device) label_batch = label_batch.to(device) node_feat, edge_feat = prep(graph_batch, data_batch) dummy_relation = torch.zeros(edge_feat.shape[0], 1).float().to(device) dummy_global = torch.zeros(node_feat.shape[0], 1).float().to(device) v_pred, _ = model( graph_batch, node_feat[:, 3:5].float(), edge_feat.float(), dummy_global, dummy_relation, ) loss = loss_fn(v_pred, label_batch) total_loss += float(loss) return total_loss / (i + 1) # Rollout Evaluation based in initial state # Need to integrate def eval_rollout(model, prep, initial_frame, n_object, device): current_frame = initial_frame.to(device) base_graph = nx.complete_graph(n_object) graph = dgl.from_networkx(base_graph).to(device) pos_buffer = [] model.eval() for step in range(100): node_feats, edge_feats = prep(graph, current_frame) dummy_relation = torch.zeros(edge_feats.shape[0], 1).float().to(device) dummy_global = torch.zeros(node_feats.shape[0], 1).float().to(device) v_pred, _ = model( graph, node_feats[:, 3:5].float(), edge_feats.float(), dummy_global, dummy_relation, ) current_frame[:, [1, 2]] += v_pred * 0.001 current_frame[:, 3:5] = v_pred pos_buffer.append(current_frame[:, [1, 2]].cpu().numpy()) pos_buffer = np.vstack(pos_buffer).reshape(100, n_object, -1) make_video(pos_buffer, "video_model.mp4") if __name__ == "__main__": argparser = argparse.ArgumentParser() argparser.add_argument( "--lr", type=float, default=0.001, help="learning rate" ) argparser.add_argument( "--epochs", type=int, default=40000, help="Number of epochs in training" ) argparser.add_argument( "--lambda_reg", type=float, default=0.001, help="regularization weight" ) argparser.add_argument( "--gpu", type=int, default=-1, help="gpu device code, -1 means cpu" ) argparser.add_argument( "--batch_size", type=int, default=100, help="size of each mini batch" ) argparser.add_argument( "--num_workers", type=int, default=0, help="number of workers for dataloading", ) argparser.add_argument( "--visualize", action="store_true", default=False, help="Whether enable trajectory rollout mode for visualization", ) args = argparser.parse_args() # Select Device to be CPU or GPU if args.gpu != -1: device = torch.device("cuda:{}".format(args.gpu)) else: device = torch.device("cpu") train_data = MultiBodyTrainDataset() valid_data = MultiBodyValidDataset() test_data = MultiBodyTestDataset() collator = MultiBodyGraphCollator(train_data.n_particles) train_dataloader = DataLoader( train_data, args.batch_size, True, collate_fn=collator, num_workers=args.num_workers, ) valid_dataloader = DataLoader( valid_data, args.batch_size, True, collate_fn=collator, num_workers=args.num_workers, ) test_full_dataloader = DataLoader( test_data, args.batch_size, True, collate_fn=collator, num_workers=args.num_workers, ) node_feats = 5 stat = { "median": torch.from_numpy(train_data.stat_median).to(device), "max": torch.from_numpy(train_data.stat_max).to(device), "min": torch.from_numpy(train_data.stat_min).to(device), } print( "Weight: ", train_data.stat_median[0], train_data.stat_max[0], train_data.stat_min[0], ) print( "Position: ", train_data.stat_median[[1, 2]], train_data.stat_max[[1, 2]], train_data.stat_min[[1, 2]], ) print( "Velocity: ", train_data.stat_median[[3, 4]], train_data.stat_max[[3, 4]], train_data.stat_min[[3, 4]], ) prepare_layer = PrepareLayer(node_feats, stat).to(device) interaction_net = InteractionNet(node_feats, stat).to(device) print(interaction_net) optimizer = torch.optim.Adam(interaction_net.parameters(), lr=args.lr) state_dict = interaction_net.state_dict() loss_fn = torch.nn.MSELoss() reg_fn = torch.nn.MSELoss(reduction="sum") try: for e in range(args.epochs): last_t = time.time() loss = train( optimizer, loss_fn, reg_fn, interaction_net, prepare_layer, train_dataloader, args.lambda_reg, device, ) print("Epoch time: ", time.time() - last_t) if e % 1 == 0: valid_loss = eval( loss_fn, interaction_net, prepare_layer, valid_dataloader, device, ) test_full_loss = eval( loss_fn, interaction_net, prepare_layer, test_full_dataloader, device, ) print( "Epoch: {}.Loss: Valid: {} Full: {}".format( e, valid_loss, test_full_loss ) ) except: traceback.print_exc() finally: if args.visualize: eval_rollout( interaction_net, prepare_layer, test_data.first_frame, test_data.n_particles, device, ) make_video(test_data.test_traj[:100, :, [1, 2]], "video_truth.mp4")