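"""Train and evaluate a point-cloud classifier on ModelNet40 (shapes pre-sampled
to 2048 points), checkpointing the model with the best validation accuracy."""
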
import argparse
import os
import urllib
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm

from dgl.data.utils import download, get_download_dir
from model import compute_loss, Model
from modelnet import ModelNet
from torch.utils.data import DataLoader

parser = argparse.ArgumentParser()
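# Empty-string defaults mean: download the dataset to DGL's cache directory,
# train from scratch, and skip checkpointing.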
parser.add_argument("--dataset-path", type=str, default="")
parser.add_argument("--load-model-path", type=str, default="")
parser.add_argument("--save-model-path", type=str, default="")
parser.add_argument("--num-epochs", type=int, default=100)
parser.add_argument("--num-workers", type=int, default=0)
parser.add_argument("--batch-size", type=int, default=32)
args = parser.parse_args()

num_workers = args.num_workers
batch_size = args.batch_size
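
# Resolve the dataset location; download the pre-sampled point clouds from
# data.dgl.ai if no local copy exists.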
data_filename = "modelnet40-sampled-2048.h5"
local_path = args.dataset_path or os.path.join(
    get_download_dir(), data_filename
)

if not os.path.exists(local_path):
    download(
        "https://data.dgl.ai/dataset/modelnet40-sampled-2048.h5", local_path
    )

CustomDataLoader = partial(
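    # The same loader settings are reused for the train/valid/test splits.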
    DataLoader,
    num_workers=num_workers,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)


def train(model, opt, scheduler, train_loader, dev):
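    """Run one training epoch, reporting running loss and accuracy via tqdm."""
    # Advance the cosine-annealing learning-rate schedule once per epoch.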
    scheduler.step()

    model.train()

    total_loss = 0
    num_batches = 0
    total_correct = 0
    count = 0
    with tqdm.tqdm(train_loader, ascii=True) as tq:
        for data, label in tq:
            num_examples = label.shape[0]
            data, label = data.to(dev), label.to(dev).squeeze().long()
            opt.zero_grad()
            logits = model(data)
            loss = compute_loss(logits, label)
            loss.backward()
            opt.step()

            _, preds = logits.max(1)

            num_batches += 1
            count += num_examples
            loss = loss.item()
            correct = (preds == label).sum().item()
            total_loss += loss
            total_correct += correct

            tq.set_postfix(
                {
                    "Loss": "%.5f" % loss,
                    "AvgLoss": "%.5f" % (total_loss / num_batches),
                    "Acc": "%.5f" % (correct / num_examples),
                    "AvgAcc": "%.5f" % (total_correct / count),
                }
            )


def evaluate(model, test_loader, dev):
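    """Evaluate the model on a data loader and return the overall accuracy."""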
    model.eval()

    total_correct = 0
    count = 0

    with torch.no_grad():
        with tqdm.tqdm(test_loader, ascii=True) as tq:
            for data, label in tq:
                num_examples = label.shape[0]
                data, label = data.to(dev), label.to(dev).squeeze().long()
                logits = model(data)
                _, preds = logits.max(1)

                correct = (preds == label).sum().item()
                total_correct += correct
                count += num_examples

                tq.set_postfix(
                    {
                        "Acc": "%.5f" % (correct / num_examples),
                        "AvgAcc": "%.5f" % (total_correct / count),
                    }
                )

    return total_correct / count


dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 40 output classes for ModelNet40; the remaining arguments are presumably the
# neighbourhood size and the per-layer feature widths used by the model.
model = Model(20, [64, 64, 128, 256], [512, 512, 256], 40)
model = model.to(dev)
if args.load_model_path:
    model.load_state_dict(torch.load(args.load_model_path, map_location=dev))

opt = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

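# Cosine-anneal the learning rate from 0.1 down to 0.001 over all epochs.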
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    opt, args.num_epochs, eta_min=0.001
)

# Load the HDF5 dataset; the second argument is presumably the number of
# points sampled per shape.
modelnet = ModelNet(local_path, 1024)

train_loader = CustomDataLoader(modelnet.train())
valid_loader = CustomDataLoader(modelnet.valid())
test_loader = CustomDataLoader(modelnet.test())

best_valid_acc = 0
best_test_acc = 0

# Each epoch validates first (so epoch 0 reports the untrained model),
# checkpoints when validation accuracy improves, then runs one training pass.
for epoch in range(args.num_epochs):
    print("Epoch #%d Validating" % epoch)
    valid_acc = evaluate(model, valid_loader, dev)
    test_acc = evaluate(model, test_loader, dev)
    if valid_acc > best_valid_acc:
        best_valid_acc = valid_acc
        best_test_acc = test_acc
        if args.save_model_path:
            torch.save(model.state_dict(), args.save_model_path)
    print(
        "Current validation acc: %.5f (best: %.5f), test acc: %.5f (best: %.5f)"
        % (valid_acc, best_valid_acc, test_acc, best_test_acc)
    )

    train(model, opt, scheduler, train_loader, dev)