"""Train and evaluate a PointTransformerCLS classifier on ModelNet40 (train_cls.py)."""
import argparse
import os
import time
from functools import partial

import provider
import torch
import torch.nn as nn
import tqdm

from dgl.data.utils import download, get_download_dir
from ModelNetDataLoader import ModelNetDataLoader
from pct import PointTransformerCLS
from torch.utils.data import DataLoader

# cuDNN is disabled for the whole process.
# NOTE(review): the reason is not recorded here (possibly reproducibility or a
# kernel incompatibility in the model) -- TODO confirm before re-enabling.
torch.backends.cudnn.enabled = False

parser = argparse.ArgumentParser()
# Directory of the extracted dataset; empty means "use the default download dir".
parser.add_argument("--dataset-path", type=str, default="")
# Optional state_dict checkpoint to initialize the model from.
parser.add_argument("--load-model-path", type=str, default="")
# Where to save the best model's state_dict; empty disables saving.
parser.add_argument("--save-model-path", type=str, default="")
parser.add_argument("--num-epochs", type=int, default=250)
parser.add_argument("--num-workers", type=int, default=8)
parser.add_argument("--batch-size", type=int, default=32)
args = parser.parse_args()

num_workers = args.num_workers
batch_size = args.batch_size

data_filename = "modelnet40_normal_resampled.zip"
download_path = os.path.join(get_download_dir(), data_filename)
local_path = args.dataset_path or os.path.join(
    get_download_dir(), "modelnet40_normal_resampled"
)

if not os.path.exists(local_path):
    # NOTE(review): verify_ssl=False disables TLS certificate checking, so the
    # archive is fetched unauthenticated. Re-enable verification (or pin a
    # checksum) if possible. This host may also no longer serve the file --
    # confirm the URL is still live.
    download(
        "https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip",
        download_path,
        verify_ssl=False,
    )
    from zipfile import ZipFile

    # Extract into the directory that should contain local_path, not always
    # into the default download dir. Otherwise a custom --dataset-path is never
    # populated. For the default local_path this directory is
    # get_download_dir(), exactly as before.
    # Presumably the archive's top-level folder is "modelnet40_normal_resampled"
    # (the default local_path relies on that) -- TODO confirm.
    extract_dir = os.path.dirname(os.path.abspath(local_path))
    with ZipFile(download_path) as z:
        z.extractall(path=extract_dir)

    # Fail fast with a clear message instead of a later, obscure error from
    # the dataset loader.
    if not os.path.exists(local_path):
        raise FileNotFoundError(
            "Dataset not found at %r after extracting %r into %r"
            % (local_path, download_path, extract_dir)
        )

# Shared DataLoader settings: shuffled batches, incomplete last batch dropped.
# NOTE(review): appears unused by the loaders constructed below -- candidate
# for removal.
CustomDataLoader = partial(
    DataLoader,
    num_workers=num_workers,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)


def train(net, opt, scheduler, train_loader, dev):
    """Run one training epoch over ``train_loader``, then step ``scheduler`` once.

    For each batch:

    * The batch is augmented on the CPU with the ``provider`` helpers:
      ``random_point_dropout`` on the whole array, then random scale, jitter
      and shift applied to feature channels ``0:3`` only.
    * The batch is fed through ``net``.
    * ``net`` is optimized with cross-entropy loss.

    Running average loss and accuracy are shown on a tqdm progress bar. A
    summary line is printed at the end of the epoch.

    Args:
        net: model put into training mode and called as ``net(data)``; its
            output is reduced with ``max(1)``, so presumably logits of shape
            (batch, num_classes) -- TODO confirm.
        opt: optimizer over ``net``'s parameters.
        scheduler: learning-rate scheduler, stepped once per call (per epoch).
        train_loader: iterable of ``(data, label)`` tensor batches; ``label``
            is indexed as ``label[:, 0]``, so it must have at least two dims.
        dev: device each batch is moved to.
    """
    net.train()

    total_loss = 0
    num_batches = 0
    total_correct = 0
    count = 0
    loss_f = nn.CrossEntropyLoss()
    start_time = time.time()
    with tqdm.tqdm(train_loader, ascii=True) as tq:
        for data, label in tq:
            # Augmentation runs in NumPy on a CPU copy of the batch.
            data = data.data.numpy()
            data = provider.random_point_dropout(data)
            # Geometric augmentations touch only channels 0:3 (presumably
            # xyz), leaving any extra per-point features unchanged.
            data[:, :, 0:3] = provider.random_scale_point_cloud(data[:, :, 0:3])
            data[:, :, 0:3] = provider.jitter_point_cloud(data[:, :, 0:3])
            data[:, :, 0:3] = provider.shift_point_cloud(data[:, :, 0:3])
            data = torch.tensor(data)
            # reshape(-1) rather than squeeze(): squeeze() turns a batch of one
            # into a 0-d tensor, which CrossEntropyLoss rejects.
            label = label[:, 0].reshape(-1)

            num_examples = label.shape[0]
            data, label = data.to(dev), label.to(dev).long()
            opt.zero_grad()
            logits = net(data)
            loss = loss_f(logits, label)
            loss.backward()
            opt.step()

            _, preds = logits.max(1)

            num_batches += 1
            count += num_examples
            total_loss += loss.item()
            total_correct += (preds == label).sum().item()

            tq.set_postfix(
                {
                    "AvgLoss": "%.5f" % (total_loss / num_batches),
                    "AvgAcc": "%.5f" % (total_correct / count),
                }
            )

    # max(..., 1) avoids ZeroDivisionError when the loader yields no batches
    # (e.g. a dataset smaller than one batch with drop_last=True).
    avg_loss = total_loss / max(num_batches, 1)
    avg_acc = total_correct / max(count, 1)
    print(
        "[Train] AvgLoss: {:.5}, AvgAcc: {:.5}, Time: {:.5}s".format(
            avg_loss,
            avg_acc,
            time.time() - start_time,
        )
    )
    scheduler.step()


def evaluate(net, test_loader, dev):
    """Return the classification accuracy of ``net`` over ``test_loader``.

    The model is put into eval mode and gradients are disabled. No
    augmentation is applied. Running accuracy is shown on a tqdm progress bar,
    and a summary line is printed at the end.

    Args:
        net: model called as ``net(data)``; its output is reduced with
            ``max(1)``, so presumably logits of shape (batch, num_classes) --
            TODO confirm.
        test_loader: iterable of ``(data, label)`` tensor batches; ``label``
            is indexed as ``label[:, 0]``, so it must have at least two dims.
        dev: device each batch is moved to.

    Returns:
        float: fraction of examples whose argmax prediction equals the label,
        or 0.0 if the loader yielded no examples.
    """
    net.eval()

    total_correct = 0
    count = 0
    start_time = time.time()
    with torch.no_grad():
        with tqdm.tqdm(test_loader, ascii=True) as tq:
            for data, label in tq:
                # reshape(-1) keeps a batch of one as a 1-D tensor; squeeze()
                # would collapse it to a 0-d scalar.
                label = label[:, 0].reshape(-1)
                num_examples = label.shape[0]
                data, label = data.to(dev), label.to(dev).long()
                logits = net(data)
                _, preds = logits.max(1)

                total_correct += (preds == label).sum().item()
                count += num_examples

                tq.set_postfix({"AvgAcc": "%.5f" % (total_correct / count)})

    # Avoid ZeroDivisionError when the loader produced no examples.
    accuracy = total_correct / count if count else 0.0
    print(
        "[Test]  AvgAcc: {:.5}, Time: {:.5}s".format(
            accuracy, time.time() - start_time
        )
    )
    return accuracy


dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = PointTransformerCLS()

net = net.to(dev)
if args.load_model_path:
    # map_location lets a checkpoint saved on another device load onto `dev`.
    net.load_state_dict(torch.load(args.load_model_path, map_location=dev))


opt = torch.optim.SGD(
    net.parameters(), lr=0.01, weight_decay=1e-4, momentum=0.9
)

# train() steps the scheduler once per call (once per epoch), so
# T_max=num_epochs anneals the learning rate over the whole run.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    opt, T_max=args.num_epochs
)

# NOTE(review): the positional 1024 is presumably the number of points sampled
# per shape -- confirm against ModelNetDataLoader.
train_dataset = ModelNetDataLoader(local_path, 1024, split="train")
test_dataset = ModelNetDataLoader(local_path, 1024, split="test")
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,
)
# drop_last=False: every test example must be scored. Dropping the trailing
# partial batch would report accuracy on only a subset of the test split.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    drop_last=False,
)

best_test_acc = 0

for epoch in range(args.num_epochs):
    print("Epoch #{}: ".format(epoch))
    train(net, opt, scheduler, train_loader, dev)

    # Evaluate after every epoch and keep the checkpoint with the best accuracy.
    # NOTE(review): the "best" model is selected on the test split, so the
    # reported best accuracy is optimistically biased; a held-out validation
    # split would avoid this.
    test_acc = evaluate(net, test_loader, dev)
    if test_acc > best_test_acc:
        best_test_acc = test_acc
        if args.save_model_path:
            torch.save(net.state_dict(), args.save_model_path)
    print("Current test acc: %.5f (best: %.5f)" % (test_acc, best_test_acc))
    print()