import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import dgl
from dgl.data.utils import download, get_download_dir
from functools import partial
import tqdm
import urllib
import os
import argparse
import time

from ShapeNet import ShapeNet
from pointnet_partseg import PointNetPartSeg, PartSegLoss
from pointnet2_partseg import PointNet2MSGPartSeg, PointNet2SSGPartSeg

# Points per shape: passed to ShapeNet(...) and to PointNetPartSeg below, and used
# to broadcast the category one-hot over every point.
NUM_POINTS = 2048
# Width of the object-category one-hot vector fed to the networks.
NUM_CATEGORIES = 16
# Number of part labels; passed as the first argument to every network constructor.
NUM_PARTS = 50

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='pointnet',
                    choices=['pointnet', 'pointnet2_ssg', 'pointnet2_msg'])
# NOTE(review): --dataset-path is parsed but never used in this script.
parser.add_argument('--dataset-path', type=str, default='')
parser.add_argument('--load-model-path', type=str, default='')
parser.add_argument('--save-model-path', type=str, default='')
parser.add_argument('--num-epochs', type=int, default=250)
parser.add_argument('--num-workers', type=int, default=4)
parser.add_argument('--batch-size', type=int, default=16)
parser.add_argument('--tensorboard', action='store_true')
args = parser.parse_args()

num_workers = args.num_workers
batch_size = args.batch_size


def collate(samples):
    """Batch (graph, category) pairs into one DGL graph and a list of categories."""
    # NOTE(review): not passed to CustomDataLoader below.
    graphs, cat = map(list, zip(*samples))
    return dgl.batch(graphs), cat


# Shared DataLoader settings for both splits. drop_last=True keeps every batch at
# exactly `batch_size`; presumably required because the PointNet++ models are
# constructed with a fixed batch_size (see below).
# NOTE(review): on the test split this silently skips up to batch_size - 1 shapes.
CustomDataLoader = partial(
    DataLoader,
    num_workers=num_workers,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True)


def _category_one_hot(cat, category_list, num_examples, dev):
    """Encode each sample's object category as a one-hot repeated over all points.

    Args:
        cat: category name of each sample in the batch.
        category_list: sorted category names; a name's index is its one-hot slot.
        num_examples: batch size B.
        dev: device for the returned tensor.

    Returns:
        Float tensor of shape (B, NUM_CATEGORIES, NUM_POINTS).
    """
    cat_ind = [category_list.index(c) for c in cat]
    one_hot = torch.tensor(np.eye(NUM_CATEGORIES)[cat_ind]).to(dev, dtype=torch.float)
    # (B, 16) -> (B, 16 * N) -> (B, N, 16) -> (B, 16, N): same one-hot at every point.
    one_hot = one_hot.repeat(1, NUM_POINTS).view(num_examples, -1, NUM_CATEGORIES)
    return one_hot.permute(0, 2, 1)


def train(net, opt, scheduler, train_loader, dev):
    """Run one training epoch, then step the LR scheduler once.

    Reads the module-level ``shapenet`` (category names) and loss ``L``.

    Returns:
        (data, preds, avg_loss, avg_acc, seconds): the last batch's inputs and
        per-point predictions (used for TensorBoard), the mean batch loss, the
        per-point accuracy over the epoch, and the wall-clock time of the epoch.
        With an empty loader, data/preds are None and the averages are 0.0.
    """
    category_list = sorted(shapenet.seg_classes.keys())
    net.train()

    total_loss = 0.0
    num_batches = 0
    total_correct = 0
    count = 0
    avg_loss = avg_acc = 0.0
    data = preds = None
    start = time.time()
    with tqdm.tqdm(train_loader, ascii=True) as tq:
        for data, label, cat in tq:
            num_examples = data.shape[0]
            data = data.to(dev, dtype=torch.float)
            label = label.to(dev, dtype=torch.long).view(-1)
            opt.zero_grad()
            # A one-hot encoding for the object category, broadcast to every point.
            cat_tensor = _category_one_hot(cat, category_list, num_examples, dev)
            logits = net(data, cat_tensor)
            loss = L(logits, label)
            loss.backward()
            opt.step()

            _, preds = logits.max(1)
            # Denominator is the number of compared points, so accuracy is per point.
            count += label.numel()
            total_loss += loss.item()
            num_batches += 1
            total_correct += (preds.view(-1) == label).sum().item()

            avg_loss = total_loss / num_batches
            avg_acc = total_correct / count
            tq.set_postfix({
                'AvgLoss': '%.5f' % avg_loss,
                'AvgAcc': '%.5f' % avg_acc})
    scheduler.step()
    end = time.time()
    return data, preds, avg_loss, avg_acc, end - start


def mIoU(preds, label, cat, cat_miou, seg_classes):
    """Accumulate per-shape part IoU into ``cat_miou``.

    A shape's IoU is the mean, over the part labels of its category
    (``seg_classes[category]``), of |pred & label| / |pred | label|; a part that
    appears in neither the prediction nor the ground truth counts as IoU 1.

    Args:
        preds, label: 2-D integer NumPy arrays, one row of per-point labels per shape.
        cat: category name of each shape.
        cat_miou: dict category -> [sum of shape IoUs, number of shapes];
            updated in place.
        seg_classes: dict category -> part labels belonging to that category.

    Returns:
        The same ``cat_miou`` dict.
    """
    for pred_row, label_row, category in zip(preds, label, cat):
        parts = seg_classes[category]
        shape_iou = 0
        for cls in parts:
            pred_mask = pred_row == cls
            label_mask = label_row == cls
            union = int(np.logical_or(pred_mask, label_mask).sum())
            inter = int(np.logical_and(pred_mask, label_mask).sum())
            shape_iou += 1 if union == 0 else inter / union
        shape_iou /= len(parts)
        cat_miou[category][0] += shape_iou
        cat_miou[category][1] += 1
    return cat_miou


def _summarize_miou(cat_miou):
    """Return (instance mIoU, per-category mIoU) from accumulated [iou_sum, count] pairs.

    Instance mIoU averages over all shapes; per-category mIoU averages each
    category's mean IoU over the categories that have at least one shape.
    Both are 0.0 when nothing has been accumulated.
    """
    seen = [(iou_sum, n) for iou_sum, n in cat_miou.values() if n > 0]
    if not seen:
        return 0.0, 0.0
    instance_miou = sum(s for s, _ in seen) / sum(n for _, n in seen)
    category_miou = sum(s / n for s, n in seen) / len(seen)
    return instance_miou, category_miou


def evaluate(net, test_loader, dev, per_cat_verbose=False):
    """Evaluate ``net`` on ``test_loader`` and return (instance mIoU, per-category mIoU).

    Reads the module-level ``shapenet`` for category names and part labels.
    When ``per_cat_verbose`` is set, prints each category's mean IoU (1 for
    categories with no evaluated shapes).
    """
    category_list = sorted(shapenet.seg_classes.keys())
    net.eval()

    # category -> [sum of shape IoUs, number of shapes], cumulative over the loader.
    cat_miou = {k: [0, 0] for k in shapenet.seg_classes}
    miou = per_cat_miou = 0.0

    with torch.no_grad():
        with tqdm.tqdm(test_loader, ascii=True) as tq:
            for data, label, cat in tq:
                num_examples = data.shape[0]
                data = data.to(dev, dtype=torch.float)
                label = label.to(dev, dtype=torch.long)
                cat_tensor = _category_one_hot(cat, category_list, num_examples, dev)
                logits = net(data, cat_tensor)
                _, preds = logits.max(1)
                cat_miou = mIoU(preds.cpu().numpy(),
                                label.view(num_examples, -1).cpu().numpy(),
                                cat, cat_miou, shapenet.seg_classes)
                # cat_miou is already cumulative: recompute the running metrics from
                # it instead of adding it into a second accumulator every batch.
                miou, per_cat_miou = _summarize_miou(cat_miou)
                tq.set_postfix({
                    'mIoU': '%.5f' % miou,
                    'per Category mIoU': '%.5f' % per_cat_miou})

    if per_cat_verbose:
        print("Per-Category mIoU:")
        for k, (iou_sum, n) in cat_miou.items():
            print("%s mIoU=%.5f" % (k, iou_sum / n if n > 0 else 1))
    return miou, per_cat_miou


dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# argparse `choices` guarantees one of these branches runs.
# NOTE(review): the PointNet++ models use input_dims=6 while ShapeNet is built with
# normal_channel=False below; confirm the input channel count matches.
if args.model == 'pointnet':
    net = PointNetPartSeg(NUM_PARTS, 3, NUM_POINTS)
elif args.model == 'pointnet2_ssg':
    net = PointNet2SSGPartSeg(NUM_PARTS, batch_size, input_dims=6)
elif args.model == 'pointnet2_msg':
    net = PointNet2MSGPartSeg(NUM_PARTS, batch_size, input_dims=6)

net = net.to(dev)
if args.load_model_path:
    # torch.load unpickles the checkpoint; only load files from trusted sources.
    net.load_state_dict(torch.load(args.load_model_path, map_location=dev))

opt = optim.Adam(net.parameters(), lr=0.001, weight_decay=1e-4)
# Halve the LR every 20 scheduler steps; train() steps it once per epoch.
scheduler = optim.lr_scheduler.StepLR(opt, step_size=20, gamma=0.5)
L = PartSegLoss()

shapenet = ShapeNet(NUM_POINTS, normal_channel=False)

train_loader = CustomDataLoader(shapenet.trainval())
test_loader = CustomDataLoader(shapenet.test())

# Tensorboard
if args.tensorboard:
    import torchvision
    from torch.utils.tensorboard import SummaryWriter
    from torchvision import datasets, transforms
    writer = SummaryWriter()

# Select 50 distinct colors for different parts (one row per part label).
color_map = torch.tensor([
    [47, 79, 79], [139, 69, 19], [112, 128, 144], [85, 107, 47], [139, 0, 0],
    [128, 128, 0], [72, 61, 139], [0, 128, 0], [188, 143, 143], [60, 179, 113],
    [205, 133, 63], [0, 139, 139], [70, 130, 180], [205, 92, 92], [154, 205, 50],
    [0, 0, 139], [50, 205, 50], [250, 250, 250], [218, 165, 32], [139, 0, 139],
    [10, 10, 10], [176, 48, 96], [72, 209, 204], [153, 50, 204], [255, 69, 0],
    [255, 145, 0], [0, 0, 205], [255, 255, 0], [0, 255, 0], [233, 150, 122],
    [220, 20, 60], [0, 191, 255], [160, 32, 240], [192, 192, 192], [173, 255, 47],
    [218, 112, 214], [216, 191, 216], [255, 127, 80], [255, 0, 255], [100, 149, 237],
    [128, 128, 128], [221, 160, 221], [144, 238, 144], [123, 104, 238], [255, 160, 122],
    [175, 238, 238], [238, 130, 238], [127, 255, 212], [255, 218, 185], [255, 105, 180],
])
# paint each point according to its pred
def paint(batched_points):
    """Map a (B, N) tensor of part labels to per-point RGB colors of shape (B, N, 3).

    Colors are looked up in the module-level ``color_map`` (one row per part
    label). The labels are first moved to ``color_map``'s device, because
    indexing a CPU tensor with CUDA indices raises and predictions live on the
    training device.

    Raises:
        ValueError: if ``batched_points`` is not 2-D.
    """
    if batched_points.dim() != 2:
        raise ValueError('expected a (B, N) tensor of part labels, got shape %s'
                         % (tuple(batched_points.shape),))
    return color_map[batched_points.to(color_map.device)]


best_test_miou = 0
best_test_per_cat_miou = 0

for epoch in range(args.num_epochs):
    data, preds, avg_loss, avg_acc, training_time = train(
        net, opt, scheduler, train_loader, dev)
    if (epoch + 1) % 5 == 0:
        print('Epoch #%d Testing' % epoch)
        test_miou, test_per_cat_miou = evaluate(
            net, test_loader, dev, per_cat_verbose=True)
        # Checkpoint only when the instance mIoU improves.
        if test_miou > best_test_miou:
            best_test_miou = test_miou
            best_test_per_cat_miou = test_per_cat_miou
            if args.save_model_path:
                torch.save(net.state_dict(), args.save_model_path)
        print('Current test mIoU: %.5f (best: %.5f), per-Category mIoU: %.5f (best: %.5f)' % (
            test_miou, best_test_miou, test_per_cat_miou, best_test_per_cat_miou))
    # Tensorboard
    if args.tensorboard:
        # Visualize the last training batch, colored by its predicted part labels.
        colored = paint(preds)
        writer.add_mesh('data', vertices=data, colors=colored, global_step=epoch)
        writer.add_scalar('training time for one epoch', training_time, global_step=epoch)
        writer.add_scalar('AvgLoss', avg_loss, global_step=epoch)
        writer.add_scalar('AvgAcc', avg_acc, global_step=epoch)
        if (epoch + 1) % 5 == 0:
            writer.add_scalar('test mIoU', test_miou, global_step=epoch)
            writer.add_scalar('best test mIoU', best_test_miou, global_step=epoch)