"""Train and evaluate a PointTransformerSeg part-segmentation model on ShapeNet.

The script parses its command-line options at import time, builds the
network, optimizer, scheduler and data loaders, and then runs the epoch loop.
Evaluation happens on the first epoch and on every fifth epoch; the best
checkpoint (by instance mIoU) is written to ``--save-model-path`` if given.
"""
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import dgl
from functools import partial
import tqdm
import argparse
import time

import provider
from ShapeNet import ShapeNet
from pct import PointTransformerSeg, PartSegLoss

# Width of the one-hot object-category vector fed to the network.
NUM_CATEGORIES = 16

parser = argparse.ArgumentParser()
parser.add_argument('--dataset-path', type=str, default='')
parser.add_argument('--load-model-path', type=str, default='')
parser.add_argument('--save-model-path', type=str, default='')
parser.add_argument('--num-epochs', type=int, default=500)
parser.add_argument('--num-workers', type=int, default=8)
parser.add_argument('--batch-size', type=int, default=16)
parser.add_argument('--tensorboard', action='store_true')
args = parser.parse_args()

num_workers = args.num_workers
batch_size = args.batch_size


def collate(samples):
    """Batch ``(graph, category)`` samples into one DGL graph and a list.

    NOTE: this helper is not passed as ``collate_fn`` to any loader created
    in this file; it is kept for callers that may import it.
    """
    graphs, cat = map(list, zip(*samples))
    return dgl.batch(graphs), cat


# Loader factory with the training defaults; keyword arguments given at call
# time override the ones bound here.
CustomDataLoader = partial(
    DataLoader,
    num_workers=num_workers,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True)


def _category_one_hot(cat, cat_to_index, eye_mat, num_examples, dev):
    """Return a ``(num_examples, NUM_CATEGORIES, 1)`` float one-hot tensor."""
    cat_ind = [cat_to_index[c] for c in cat]
    cat_tensor = torch.tensor(eye_mat[cat_ind]).to(dev, dtype=torch.float)
    return cat_tensor.view(num_examples, NUM_CATEGORIES, 1)


def train(net, opt, scheduler, train_loader, dev):
    """Run one training epoch.

    Args:
        net: The segmentation network, called as ``net(data, cat_tensor)``.
        opt: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per epoch.
        train_loader: Iterable yielding ``(data, label, cat)`` batches.
        dev: Device the batch tensors are moved to.

    Returns:
        A tuple ``(data, preds, avg_loss, avg_acc, seconds)`` where ``data``
        and ``preds`` come from the last batch of the epoch (used for
        visualisation by the caller).
    """
    category_list = sorted(shapenet.seg_classes.keys())
    cat_to_index = {name: idx for idx, name in enumerate(category_list)}
    eye_mat = np.eye(NUM_CATEGORIES)
    net.train()

    total_loss = 0
    num_batches = 0
    total_correct = 0
    count = 0

    start = time.time()
    with tqdm.tqdm(train_loader, ascii=True) as tq:
        for data, label, cat in tq:
            num_examples = data.shape[0]
            data = data.to(dev, dtype=torch.float)
            label = label.to(dev, dtype=torch.long).view(-1)
            opt.zero_grad()
            # One-hot encoding of the object category of each shape.
            cat_tensor = _category_one_hot(
                cat, cat_to_index, eye_mat, num_examples, dev)
            logits = net(data, cat_tensor)
            loss = L(logits, label)
            loss.backward()
            opt.step()

            _, preds = logits.max(1)

            # Number of labelled points in this batch; avoids hard-coding
            # the points-per-shape count.
            count += label.numel()
            total_loss += loss.item()
            num_batches += 1
            total_correct += (preds.view(-1) == label).sum().item()

            avg_loss = total_loss / num_batches
            avg_acc = total_correct / count

            tq.set_postfix({
                'AvgLoss': '%.5f' % avg_loss,
                'AvgAcc': '%.5f' % avg_acc})
    scheduler.step()
    end = time.time()
    print("[Train] AvgLoss: {:.5}, AvgAcc: {:.5}, Time: {:.5}s".format(
        total_loss / num_batches, total_correct / count, end - start))
    return data, preds, avg_loss, avg_acc, end - start


def mIoU(preds, label, cat, cat_miou, seg_classes):
    """Accumulate per-shape mean IoU into ``cat_miou``.

    Args:
        preds: 2-D NumPy array of predicted part ids, one row per shape.
        label: 2-D NumPy array of ground-truth part ids, same layout.
        cat: Sequence with the category name of each shape in the batch.
        cat_miou: Dict mapping category name to ``[iou_sum, num_shapes]``;
            updated in place.
        seg_classes: Dict mapping category name to its part ids.

    Returns:
        The same ``cat_miou`` dict, after the update.
    """
    for i in range(preds.shape[0]):
        parts = seg_classes[cat[i]]
        shape_iou = 0
        for cls in parts:
            pred_mask = preds[i, :] == cls
            label_mask = label[i, :] == cls
            union = int(np.count_nonzero(pred_mask | label_mask))
            inter = int(np.count_nonzero(pred_mask & label_mask))
            # A part absent from both prediction and label counts as a
            # perfect match.
            if union == 0:
                shape_iou += 1
            else:
                shape_iou += inter / union
        shape_iou /= len(parts)
        cat_miou[cat[i]][0] += shape_iou
        cat_miou[cat[i]][1] += 1

    return cat_miou


def _average_miou(cat_miou):
    """Return ``(instance_mIoU, per_category_mIoU)`` for the running totals.

    Categories with no evaluated shapes are ignored. Returns ``(0.0, 0.0)``
    when nothing has been evaluated yet.
    """
    iou_sum = 0
    shape_count = 0
    cat_avg_sum = 0
    cat_count = 0
    for cat_iou_sum, cat_shapes in cat_miou.values():
        if cat_shapes > 0:
            iou_sum += cat_iou_sum
            shape_count += cat_shapes
            cat_avg_sum += cat_iou_sum / cat_shapes
            cat_count += 1
    if shape_count == 0:
        return 0.0, 0.0
    return iou_sum / shape_count, cat_avg_sum / cat_count


def evaluate(net, test_loader, dev, per_cat_verbose=False):
    """Evaluate ``net`` and report instance and per-category mIoU.

    Args:
        net: The segmentation network, called as ``net(data, cat_tensor)``.
        test_loader: Iterable yielding ``(data, label, cat)`` batches.
        dev: Device the batch tensors are moved to.
        per_cat_verbose: When true, also print the mIoU of every category.

    Returns:
        A tuple ``(instance_mIoU, per_category_mIoU)``.
    """
    category_list = sorted(shapenet.seg_classes.keys())
    cat_to_index = {name: idx for idx, name in enumerate(category_list)}
    eye_mat = np.eye(NUM_CATEGORIES)
    net.eval()

    cat_miou = {k: [0, 0] for k in shapenet.seg_classes.keys()}
    miou = 0.0
    per_cat_miou = 0.0

    with torch.no_grad():
        with tqdm.tqdm(test_loader, ascii=True) as tq:
            for data, label, cat in tq:
                num_examples = data.shape[0]
                data = data.to(dev, dtype=torch.float)
                label = label.to(dev, dtype=torch.long)
                cat_tensor = _category_one_hot(
                    cat, cat_to_index, eye_mat, num_examples, dev)
                logits = net(data, cat_tensor)
                _, preds = logits.max(1)

                cat_miou = mIoU(preds.cpu().numpy(),
                                label.view(num_examples, -1).cpu().numpy(),
                                cat, cat_miou, shapenet.seg_classes)
                # cat_miou already holds cumulative totals, so the averages
                # are recomputed from scratch instead of being added to the
                # previous batch's sums (which would over-weight early
                # batches).
                miou, per_cat_miou = _average_miou(cat_miou)
                tq.set_postfix({
                    'mIoU': '%.5f' % miou,
                    'per Category mIoU': '%.5f' % per_cat_miou})
    print("[Test] mIoU: %.5f, per Category mIoU: %.5f" %
          (miou, per_cat_miou))
    if per_cat_verbose:
        print("-" * 60)
        print("Per-Category mIoU:")
        for k, v in cat_miou.items():
            if v[1] > 0:
                print("%s mIoU=%.5f" % (k, v[0] / v[1]))
            else:
                print("%s mIoU=%.5f" % (k, 1))
        print("-" * 60)
    return miou, per_cat_miou


dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = PointTransformerSeg()

net = net.to(dev)
if args.load_model_path:
    net.load_state_dict(torch.load(args.load_model_path, map_location=dev))

opt = torch.optim.SGD(
    net.parameters(),
    lr=0.01,
    weight_decay=1e-4,
    momentum=0.9
)

# The scheduler is stepped once per epoch inside train().
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    opt, T_max=args.num_epochs)

L = PartSegLoss()

shapenet = ShapeNet(2048, normal_channel=False)

train_loader = CustomDataLoader(shapenet.trainval())
# Evaluate on every test shape in a fixed order: the training defaults
# (shuffle + drop_last) would silently discard the last partial batch.
test_loader = CustomDataLoader(
    shapenet.test(), shuffle=False, drop_last=False)

# Tensorboard
if args.tensorboard:
    import torchvision
    from torch.utils.tensorboard import SummaryWriter
    from torchvision import datasets, transforms
    writer = SummaryWriter()
# Select 50 distinct colors for different parts
color_map = torch.tensor([
    [47, 79, 79], [139, 69, 19], [112, 128, 144], [85, 107, 47],
    [139, 0, 0], [128, 128, 0], [72, 61, 139], [0, 128, 0],
    [188, 143, 143], [60, 179, 113], [205, 133, 63], [0, 139, 139],
    [70, 130, 180], [205, 92, 92], [154, 205, 50], [0, 0, 139],
    [50, 205, 50], [250, 250, 250], [218, 165, 32], [139, 0, 139],
    [10, 10, 10], [176, 48, 96], [72, 209, 204], [153, 50, 204],
    [255, 69, 0], [255, 145, 0], [0, 0, 205], [255, 255, 0],
    [0, 255, 0], [233, 150, 122], [220, 20, 60], [0, 191, 255],
    [160, 32, 240], [192, 192, 192], [173, 255, 47], [218, 112, 214],
    [216, 191, 216], [255, 127, 80], [255, 0, 255], [100, 149, 237],
    [128, 128, 128], [221, 160, 221], [144, 238, 144], [123, 104, 238],
    [255, 160, 122], [175, 238, 238], [238, 130, 238], [127, 255, 212],
    [255, 218, 185], [255, 105, 180],
])


# paint each point according to its pred
def paint(batched_points):
    """Map a 2-D tensor of part ids to RGB colours from ``color_map``.

    The indices are moved to the CPU first because ``color_map`` lives on the
    CPU while predictions may be on another device.
    """
    return color_map[batched_points.cpu()]


best_test_miou = 0
best_test_per_cat_miou = 0

for epoch in range(args.num_epochs):
    print("Epoch #{}: ".format(epoch))
    data, preds, AvgLoss, AvgAcc, training_time = train(
        net, opt, scheduler, train_loader, dev)
    # Evaluate on the first epoch and then on every fifth one.
    evaluated = (epoch + 1) % 5 == 0 or epoch == 0
    if evaluated:
        test_miou, test_per_cat_miou = evaluate(
            net, test_loader, dev, True)
        if test_miou > best_test_miou:
            best_test_miou = test_miou
            best_test_per_cat_miou = test_per_cat_miou
            if args.save_model_path:
                torch.save(net.state_dict(), args.save_model_path)
        print('Current test mIoU: %.5f (best: %.5f), per-Category mIoU: %.5f (best: %.5f)' % (
            test_miou, best_test_miou, test_per_cat_miou, best_test_per_cat_miou))
    # Tensorboard
    if args.tensorboard:
        colored = paint(preds)
        writer.add_mesh('data', vertices=data,
                        colors=colored, global_step=epoch)
        writer.add_scalar('training time for one epoch',
                          training_time, global_step=epoch)
        writer.add_scalar('AvgLoss', AvgLoss, global_step=epoch)
        writer.add_scalar('AvgAcc', AvgAcc, global_step=epoch)
        # Log test metrics whenever an evaluation actually ran this epoch.
        if evaluated:
            writer.add_scalar('test mIoU', test_miou, global_step=epoch)
            writer.add_scalar('best test mIoU',
                              best_test_miou, global_step=epoch)
    print()