"""Train and evaluate ``PointTransformerCLS`` on ModelNet40 (normal-resampled).

If ``--dataset-path`` is not given, the dataset is looked up under DGL's
download directory and downloaded/extracted there on first use. The model
state with the best test accuracy so far is written to ``--save-model-path``
(when given).
"""
import argparse
import os
import time
from functools import partial
from zipfile import ZipFile

import torch
import torch.nn as nn
import tqdm
from dgl.data.utils import download, get_download_dir
from torch.utils.data import DataLoader

import provider
from ModelNetDataLoader import ModelNetDataLoader
from pct import PointTransformerCLS

DATA_URL = 'https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip'
DATA_FILENAME = 'modelnet40_normal_resampled.zip'
DATA_DIRNAME = 'modelnet40_normal_resampled'
NUM_POINTS = 1024  # second positional argument passed to ModelNetDataLoader


def _parse_args(argv=None):
    """Parse and validate the command-line options."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset-path', type=str, default='')
    parser.add_argument('--load-model-path', type=str, default='')
    parser.add_argument('--save-model-path', type=str, default='')
    parser.add_argument('--num-epochs', type=int, default=250)
    parser.add_argument('--num-workers', type=int, default=8)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--eval-interval', type=int, default=1,
                        help='evaluate on the test split every N epochs')
    args = parser.parse_args(argv)
    if args.eval_interval < 1:
        parser.error('--eval-interval must be >= 1')
    # The automatic download only targets the default location, so an
    # explicit path must already exist; fail now instead of after a download.
    if args.dataset_path and not os.path.exists(args.dataset_path):
        parser.error('--dataset-path does not exist: %s' % args.dataset_path)
    return args


def _prepare_dataset(dataset_path=''):
    """Return the dataset directory, downloading it first if necessary.

    A non-empty *dataset_path* is returned unchanged. Otherwise the default
    directory under ``get_download_dir()`` is used; if it is missing, the
    zip archive is downloaded and extracted into ``get_download_dir()``.
    """
    if dataset_path:
        return dataset_path
    download_dir = get_download_dir()
    local_path = os.path.join(download_dir, DATA_DIRNAME)
    if not os.path.exists(local_path):
        download_path = os.path.join(download_dir, DATA_FILENAME)
        # NOTE(review): SSL verification is disabled, as in the original
        # script; consider re-enabling it if the host serves a valid cert.
        download(DATA_URL, download_path, verify_ssl=False)
        with ZipFile(download_path) as z:
            z.extractall(path=download_dir)
    return local_path


def train(net, opt, scheduler, train_loader, dev):
    """Run one training epoch over *train_loader*, then step *scheduler*.

    Each batch is converted to NumPy and augmented with ``provider``'s random
    point dropout, then random scaling, jitter and shift applied to feature
    channels 0:3 only. It is then moved to *dev* and used for one SGD step
    with cross-entropy loss. Running loss/accuracy are shown on the progress
    bar and a one-line summary is printed at the end.
    """
    net.train()
    loss_f = nn.CrossEntropyLoss()
    total_loss = 0.0
    total_correct = 0
    num_batches = 0
    count = 0
    start_time = time.time()
    with tqdm.tqdm(train_loader, ascii=True) as tq:
        for data, label in tq:
            data = data.numpy()
            data = provider.random_point_dropout(data)
            # Geometric augmentation touches only channels 0:3 (presumably
            # xyz coordinates; remaining channels are left as-is).
            data[:, :, 0:3] = provider.random_scale_point_cloud(
                data[:, :, 0:3])
            data[:, :, 0:3] = provider.jitter_point_cloud(data[:, :, 0:3])
            data[:, :, 0:3] = provider.shift_point_cloud(data[:, :, 0:3])
            data = torch.tensor(data)
            # Flatten to (batch,); unlike squeeze(), this keeps a batch of
            # size 1 one-dimensional as CrossEntropyLoss expects.
            label = label[:, 0].reshape(-1).long()
            num_examples = label.shape[0]
            data, label = data.to(dev), label.to(dev)

            opt.zero_grad()
            logits = net(data)
            loss = loss_f(logits, label)
            loss.backward()
            opt.step()

            _, preds = logits.max(1)
            num_batches += 1
            count += num_examples
            total_loss += loss.item()
            total_correct += (preds == label).sum().item()

            tq.set_postfix({
                'AvgLoss': '%.5f' % (total_loss / num_batches),
                'AvgAcc': '%.5f' % (total_correct / count)})
    # max(..., 1) guards against a loader that produced no batches.
    print("[Train] AvgLoss: {:.5}, AvgAcc: {:.5}, Time: {:.5}s".format(
        total_loss / max(num_batches, 1), total_correct / max(count, 1),
        time.time() - start_time))
    scheduler.step()


def evaluate(net, test_loader, dev):
    """Return the top-1 accuracy of *net* over *test_loader*.

    Runs without gradients in eval mode, prints a summary line, and returns
    0.0 if the loader yields no samples.
    """
    net.eval()
    total_correct = 0
    count = 0
    start_time = time.time()
    with torch.no_grad():
        with tqdm.tqdm(test_loader, ascii=True) as tq:
            for data, label in tq:
                label = label[:, 0].reshape(-1).long()
                num_examples = label.shape[0]
                data, label = data.to(dev), label.to(dev)
                logits = net(data)
                _, preds = logits.max(1)
                total_correct += (preds == label).sum().item()
                count += num_examples
                tq.set_postfix({
                    'AvgAcc': '%.5f' % (total_correct / count)})
    accuracy = total_correct / count if count else 0.0
    print("[Test] AvgAcc: {:.5}, Time: {:.5}s".format(
        accuracy, time.time() - start_time))
    return accuracy


def main(argv=None):
    """Parse options, build model/optimizer/loaders and run training."""
    args = _parse_args(argv)
    # cuDNN is disabled as in the original script (reason not stated here).
    torch.backends.cudnn.enabled = False
    local_path = _prepare_dataset(args.dataset_path)

    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = PointTransformerCLS().to(dev)
    if args.load_model_path:
        # torch.load unpickles the file: only load checkpoints you trust.
        net.load_state_dict(
            torch.load(args.load_model_path, map_location=dev))

    opt = torch.optim.SGD(
        net.parameters(), lr=0.01, weight_decay=1e-4, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        opt, T_max=args.num_epochs)

    make_loader = partial(
        DataLoader, num_workers=args.num_workers, batch_size=args.batch_size)
    train_loader = make_loader(
        ModelNetDataLoader(local_path, NUM_POINTS, split='train'),
        shuffle=True, drop_last=True)
    # Keep the last partial batch: dropping it would exclude test samples
    # from the reported accuracy.
    test_loader = make_loader(
        ModelNetDataLoader(local_path, NUM_POINTS, split='test'),
        shuffle=False, drop_last=False)

    best_test_acc = 0
    for epoch in range(args.num_epochs):
        print("Epoch #{}: ".format(epoch))
        train(net, opt, scheduler, train_loader, dev)
        if (epoch + 1) % args.eval_interval == 0:
            test_acc = evaluate(net, test_loader, dev)
            if test_acc > best_test_acc:
                best_test_acc = test_acc
                if args.save_model_path:
                    torch.save(net.state_dict(), args.save_model_path)
            print('Current test acc: %.5f (best: %.5f)' % (
                test_acc, best_test_acc))
        print()


# Required for DataLoader worker processes started with "spawn" (Windows,
# macOS): they re-import this module and must not re-run training.
if __name__ == "__main__":
    main()