train_partseg.py 9.09 KB
Newer Older
esang's avatar
esang committed
1
2
import argparse
import time
3
from functools import partial
esang's avatar
esang committed
4

5
import numpy as np
esang's avatar
esang committed
6
import provider
7
8
9
10
import torch
import torch.optim as optim
import tqdm
from pct import PartSegLoss, PointTransformerSeg
esang's avatar
esang committed
11
from ShapeNet import ShapeNet
12
13
14
from torch.utils.data import DataLoader

import dgl
esang's avatar
esang committed
15
16

# Command-line options for the part-segmentation training run.
# NOTE(review): --dataset-path is parsed but never read in this file (ShapeNet
# is constructed without it) — confirm whether it should be wired through.
parser = argparse.ArgumentParser()
# Path-valued options; an empty string means "not provided".
for _opt in ("--dataset-path", "--load-model-path", "--save-model-path"):
    parser.add_argument(_opt, type=str, default="")
# Integer run settings with their defaults.
for _opt, _default in (
    ("--num-epochs", 500),
    ("--num-workers", 8),
    ("--batch-size", 16),
):
    parser.add_argument(_opt, type=int, default=_default)
del _opt, _default
parser.add_argument("--tensorboard", action="store_true")
args = parser.parse_args()

# Loader settings consumed by the DataLoader factory below.
num_workers, batch_size = args.num_workers, args.batch_size


def collate(samples):
    """Transpose ``(graph, category)`` pairs and batch the graphs with DGL.

    Returns ``(batched_graph, categories)`` where ``categories`` is a list in
    sample order. Raises ValueError if samples are empty or not 2-tuples.

    NOTE(review): CustomDataLoader in this file does not pass this function as
    ``collate_fn``, so it appears unused here.
    """
    columns = [list(column) for column in zip(*samples)]
    graph_list, categories = columns
    return dgl.batch(graph_list), categories


# DataLoader factory pre-bound with the CLI batch size / worker count.
# Batches are shuffled and the trailing partial batch is dropped unless the
# caller overrides these keywords.
CustomDataLoader = partial(
    DataLoader,
    batch_size=batch_size,
    num_workers=num_workers,
    drop_last=True,
    shuffle=True,
)


def train(net, opt, scheduler, train_loader, dev):
    """Train ``net`` for one epoch over ``train_loader``, then step the scheduler.

    Uses the module-level ``shapenet`` (for the sorted category list) and the
    module-level loss ``L``.

    Args:
        net: segmentation model, called as ``net(points, one_hot_category)``.
        opt: optimizer over ``net``'s parameters.
        scheduler: LR scheduler; stepped once after the epoch.
        train_loader: iterable of ``(data, label, cat)`` batches.
        dev: device the batch tensors are moved to.

    Returns:
        ``(data, preds, avg_loss, avg_acc, elapsed_seconds)``; ``data`` and
        ``preds`` are from the last batch of the epoch.

    Raises:
        RuntimeError: if ``train_loader`` yields no batches.
    """
    category_list = sorted(shapenet.seg_classes.keys())
    # One-hot rows for the category input; assumes 16 object categories
    # (matches the view below) — TODO confirm against the model.
    eye_mat = np.eye(16)
    net.train()

    total_loss = 0.0
    num_batches = 0
    total_correct = 0
    count = 0
    start = time.time()
    with tqdm.tqdm(train_loader, ascii=True) as tq:
        for data, label, cat in tq:
            num_examples = data.shape[0]
            data = data.to(dev, dtype=torch.float)
            label = label.to(dev, dtype=torch.long).view(-1)
            opt.zero_grad()
            cat_ind = [category_list.index(c) for c in cat]
            # One-hot encoding of each sample's object category.
            cat_tensor = torch.tensor(eye_mat[cat_ind]).to(dev, dtype=torch.float)
            cat_tensor = cat_tensor.view(num_examples, 16, 1)
            logits = net(data, cat_tensor)
            loss = L(logits, label)
            loss.backward()
            opt.step()

            _, preds = logits.max(1)

            # Count exactly the points compared below rather than assuming a
            # fixed 2048 points per shape.
            count += label.numel()
            total_loss += loss.item()
            num_batches += 1
            total_correct += (preds.view(-1) == label).sum().item()

            AvgLoss = total_loss / num_batches
            AvgAcc = total_correct / count

            tq.set_postfix(
                {"AvgLoss": "%.5f" % AvgLoss, "AvgAcc": "%.5f" % AvgAcc}
            )

    if num_batches == 0:
        raise RuntimeError(
            "train_loader yielded no batches (dataset smaller than one batch "
            "with drop_last=True?)"
        )
    scheduler.step()
    end = time.time()
    print(
        "[Train] AvgLoss: {:.5}, AvgAcc: {:.5}, Time: {:.5}s".format(
            AvgLoss, AvgAcc, end - start
        )
    )
    return data, preds, AvgLoss, AvgAcc, end - start


def mIoU(preds, label, cat, cat_miou, seg_classes):
    """Accumulate each shape's mean part IoU into per-category running totals.

    For row ``i``, every part id in ``seg_classes[cat[i]]`` is scored by the
    IoU of the index sets where ``preds[i, :]`` and ``label[i, :]`` equal that
    id. A part absent from both prediction and label scores 1. The mean over
    the category's parts is added to ``cat_miou[cat[i]][0]``, and
    ``cat_miou[cat[i]][1]`` is incremented.

    Args:
        preds: array of predicted part ids, one row per shape.
        label: array of ground-truth part ids, indexed like ``preds``.
        cat: category name per row.
        cat_miou: dict ``category -> [iou_sum, shape_count]``; mutated in place.
        seg_classes: dict ``category -> part ids`` belonging to that category.

    Returns:
        The same ``cat_miou`` dict.
    """
    for i in range(preds.shape[0]):
        category = cat[i]
        parts = seg_classes[category]
        num_parts = len(parts)
        pred_row = preds[i, :]
        label_row = label[i, :]

        part_scores = []
        for part in parts:
            pred_idx = set(np.nonzero(pred_row == part)[0])
            label_idx = set(np.nonzero(label_row == part)[0])
            union_size = len(pred_idx | label_idx)
            inter_size = len(pred_idx & label_idx)
            part_scores.append(1 if union_size == 0 else inter_size / union_size)

        totals = cat_miou[category]
        totals[0] += sum(part_scores) / num_parts
        totals[1] += 1

    return cat_miou


def _summarize_miou(cat_miou):
    """Reduce ``category -> [iou_sum, shape_count]`` totals to two scores.

    Returns ``(instance_miou, class_miou)``:
    - ``instance_miou`` is the mean IoU over every scored shape.
    - ``class_miou`` is the unweighted mean of per-category mean IoUs, taken
      over categories with at least one shape.

    Both are 0.0 when no shape has been scored.
    """
    iou_sum = 0.0
    shape_count = 0
    category_means = []
    for total, n in cat_miou.values():
        if n > 0:
            iou_sum += total
            shape_count += n
            category_means.append(total / n)
    if shape_count == 0:
        return 0.0, 0.0
    return iou_sum / shape_count, sum(category_means) / len(category_means)


def evaluate(net, test_loader, dev, per_cat_verbose=False):
    """Evaluate ``net`` on ``test_loader`` and report part-segmentation mIoU.

    Uses the module-level ``shapenet`` for the category list and part ids.

    Args:
        net: segmentation model, called as ``net(points, one_hot_category)``.
        test_loader: iterable of ``(data, label, cat)`` batches.
        dev: device the batch tensors are moved to.
        per_cat_verbose: if true, also print each category's mIoU.

    Returns:
        ``(instance_miou, class_miou)`` over all evaluated shapes.

    Raises:
        RuntimeError: if ``test_loader`` yields no shapes.
    """
    category_list = sorted(shapenet.seg_classes.keys())
    # One-hot rows for the category input; assumes 16 object categories
    # (matches the view below) — TODO confirm against the model.
    eye_mat = np.eye(16)
    net.eval()

    # category -> [sum of shape IoUs, number of shapes], cumulative over batches
    cat_miou = {k: [0, 0] for k in shapenet.seg_classes.keys()}
    miou, per_cat_miou = 0.0, 0.0

    with torch.no_grad():
        with tqdm.tqdm(test_loader, ascii=True) as tq:
            for data, label, cat in tq:
                num_examples = data.shape[0]
                data = data.to(dev, dtype=torch.float)
                label = label.to(dev, dtype=torch.long)
                cat_ind = [category_list.index(c) for c in cat]
                cat_tensor = torch.tensor(eye_mat[cat_ind]).to(
                    dev, dtype=torch.float
                )
                cat_tensor = cat_tensor.view(num_examples, 16, 1)
                logits = net(data, cat_tensor)
                _, preds = logits.max(1)

                cat_miou = mIoU(
                    preds.cpu().numpy(),
                    label.view(num_examples, -1).cpu().numpy(),
                    cat,
                    cat_miou,
                    shapenet.seg_classes,
                )
                # cat_miou is already cumulative: recompute the scores from it
                # instead of adding it into running accumulators, which would
                # count earlier batches again on every step.
                miou, per_cat_miou = _summarize_miou(cat_miou)
                tq.set_postfix(
                    {
                        "mIoU": "%.5f" % miou,
                        "per Category mIoU": "%.5f" % per_cat_miou,
                    }
                )

    if not any(n > 0 for _, n in cat_miou.values()):
        raise RuntimeError("test_loader yielded no shapes to evaluate")
    print(
        "[Test] mIoU: %.5f, per Category mIoU: %.5f" % (miou, per_cat_miou)
    )
    if per_cat_verbose:
        print("-" * 60)
        print("Per-Category mIoU:")
        for k, v in cat_miou.items():
            if v[1] > 0:
                print("%s mIoU=%.5f" % (k, v[0] / v[1]))
            else:
                print("%s mIoU=%.5f" % (k, 1))
        print("-" * 60)
    return miou, per_cat_miou


dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = PointTransformerSeg().to(dev)
if args.load_model_path:
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    net.load_state_dict(torch.load(args.load_model_path, map_location=dev))

opt = torch.optim.SGD(
    net.parameters(), lr=0.01, weight_decay=1e-4, momentum=0.9
)
# Cosine-annealed learning rate with period equal to the number of epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    opt, T_max=args.num_epochs
)

L = PartSegLoss()

shapenet = ShapeNet(2048, normal_channel=False)

train_loader = CustomDataLoader(shapenet.trainval())
# The shared factory shuffles and drops the last partial batch; for testing
# every shape must be scored, so override both.
test_loader = CustomDataLoader(shapenet.test(), shuffle=False, drop_last=False)

# Tensorboard
if args.tensorboard:
    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter()
# Palette of 50 distinct RGB colors (0-255), one per part id:
# row k is the color used for part k by paint().
color_map = torch.tensor(
    [
        (47, 79, 79), (139, 69, 19), (112, 128, 144), (85, 107, 47), (139, 0, 0),
        (128, 128, 0), (72, 61, 139), (0, 128, 0), (188, 143, 143), (60, 179, 113),
        (205, 133, 63), (0, 139, 139), (70, 130, 180), (205, 92, 92), (154, 205, 50),
        (0, 0, 139), (50, 205, 50), (250, 250, 250), (218, 165, 32), (139, 0, 139),
        (10, 10, 10), (176, 48, 96), (72, 209, 204), (153, 50, 204), (255, 69, 0),
        (255, 145, 0), (0, 0, 205), (255, 255, 0), (0, 255, 0), (233, 150, 122),
        (220, 20, 60), (0, 191, 255), (160, 32, 240), (192, 192, 192), (173, 255, 47),
        (218, 112, 214), (216, 191, 216), (255, 127, 80), (255, 0, 255), (100, 149, 237),
        (128, 128, 128), (221, 160, 221), (144, 238, 144), (123, 104, 238), (255, 160, 122),
        (175, 238, 238), (238, 130, 238), (127, 255, 212), (255, 218, 185), (255, 105, 180),
    ]
)
# paint each point according to its pred


def paint(batched_points):
    """Look up the RGB color of every part id in ``batched_points``.

    ``batched_points`` must be a 2-D tensor of part ids; unpacking its shape
    raises ValueError otherwise. Indexing ``color_map`` appends an RGB axis.
    ``squeeze(2)`` removes axis 2 only if it has size 1.
    """
    _batch_size, _num_points = batched_points.shape
    return color_map[batched_points].squeeze(2)


best_test_miou = 0
best_test_per_cat_miou = 0

for epoch in range(args.num_epochs):
    print("Epoch #{}: ".format(epoch))
    data, preds, AvgLoss, AvgAcc, training_time = train(
        net, opt, scheduler, train_loader, dev
    )
    # Evaluate after the first epoch, every 5th epoch, and the final epoch so
    # the last weights are always considered for checkpointing.
    if (
        (epoch + 1) % 5 == 0
        or epoch == 0
        or epoch == args.num_epochs - 1
    ):
        test_miou, test_per_cat_miou = evaluate(net, test_loader, dev, True)
        if test_miou > best_test_miou:
            best_test_miou = test_miou
            best_test_per_cat_miou = test_per_cat_miou
            if args.save_model_path:
                torch.save(net.state_dict(), args.save_model_path)
        print(
            "Current test mIoU: %.5f (best: %.5f), per-Category mIoU: %.5f (best: %.5f)"
            % (
                test_miou,
                best_test_miou,
                test_per_cat_miou,
                best_test_per_cat_miou,
            )
        )
    # Tensorboard
    if args.tensorboard:
        # color_map is a CPU tensor; indexing it with predictions that live on
        # a CUDA device raises, so bring the last batch back to the CPU first.
        colored = paint(preds.cpu())
        writer.add_mesh(
            "data", vertices=data.cpu(), colors=colored, global_step=epoch
        )
        writer.add_scalar(
            "training time for one epoch", training_time, global_step=epoch
        )
        writer.add_scalar("AvgLoss", AvgLoss, global_step=epoch)
        writer.add_scalar("AvgAcc", AvgAcc, global_step=epoch)
        if (epoch + 1) % 5 == 0:
            writer.add_scalar("test mIoU", test_miou, global_step=epoch)
            writer.add_scalar(
                "best test mIoU", best_test_miou, global_step=epoch
            )
    print()