"src/diffusers/pipelines/bria/__init__.py" did not exist on "413604405fddb4692a8e9a9a9fb6c353d22881ea"
main.py 4.92 KB
Newer Older
kitaev-chen's avatar
kitaev-chen committed
1
2
3
4
5
6
7
8
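"""Train and evaluate a Graph Isomorphism Network (GIN) for graph classification.

The script loads a dataset via dgl.data.GINDataset, builds 10-fold (or random)
train/validation loaders, trains the GIN model imported from gin, and reports
per-epoch loss and accuracy on both splits.
"""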
import sys
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim

from dgl.data import GINDataset
from dataloader import GINDataLoader
from ginparser import Parser
from gin import GIN


def train(args, net, trainloader, optimizer, criterion, epoch):
    net.train()

    running_loss = 0
    total_iters = len(trainloader)
    # offset the progress bar position to avoid overlap with the mouse cursor
    bar = tqdm(range(total_iters), unit='batch', position=2, file=sys.stdout)

    for pos, (graphs, labels) in zip(bar, trainloader):
        # batch graphs will be shipped to device in forward part of model
        labels = labels.to(args.device)
        graphs = graphs.to(args.device)
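        # GINDataset stores the node features under the 'attr' key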
        feat = graphs.ndata.pop('attr')
        outputs = net(graphs, feat)

        loss = criterion(outputs, labels)
        running_loss += loss.item()

        # backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # report
        bar.set_description('epoch-{}'.format(epoch))
    bar.close()
    # report the average loss per iteration
    running_loss = running_loss / total_iters

    return running_loss


def eval_net(args, net, dataloader, criterion):
    net.eval()

    total = 0
    total_loss = 0
    total_correct = 0

    for data in dataloader:
        graphs, labels = data
        graphs = graphs.to(args.device)
        labels = labels.to(args.device)
        feat = graphs.ndata.pop('attr')
        total += len(labels)
        outputs = net(graphs, feat)
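        # the predicted class is the index of the largest logit for each graph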
        _, predicted = torch.max(outputs.data, 1)

        total_correct += (predicted == labels.data).sum().item()
        loss = criterion(outputs, labels)
        # CrossEntropyLoss averages over the batch by default, so weight by the batch size
        total_loss += loss.item() * len(labels)

    loss, acc = 1.0*total_loss / total, 1.0*total_correct / total

    net.train()

    return loss, acc


def main(args):

    # set up random seeds (controlled by args.seed)
    torch.manual_seed(seed=args.seed)
    np.random.seed(seed=args.seed)

    is_cuda = not args.disable_cuda and torch.cuda.is_available()

    if is_cuda:
        args.device = torch.device("cuda:" + str(args.device))
        torch.cuda.manual_seed_all(seed=args.seed)
    else:
        args.device = torch.device("cpu")

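    # self-loops are added only when epsilon is not learned;
    # args.degree_as_nlabel optionally uses the node degree as the node label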
    dataset = GINDataset(args.dataset, not args.learn_eps, args.degree_as_nlabel)
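    # 10-fold cross-validation split; args.fold_idx selects which fold is held out for validation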
    trainloader, validloader = GINDataLoader(
        dataset, batch_size=args.batch_size, device=args.device,
        seed=args.seed, shuffle=True,
        split_name='fold10', fold_idx=args.fold_idx).train_valid_loader()
    # or split_name='rand', split_ratio=0.7

    model = GIN(
        args.num_layers, args.num_mlp_layers,
        dataset.dim_nfeats, args.hidden_dim, dataset.gclasses,
        args.final_dropout, args.learn_eps,
        args.graph_pooling_type, args.neighbor_pooling_type).to(args.device)

    criterion = nn.CrossEntropyLoss()  # default reduction is 'mean'
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
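    # decay the learning rate by a factor of 0.5 every 50 epochs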
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

    # keeping the bars at fixed offset positions is simpler than handling the cursor at position 0
    # https://stackoverflow.com/a/23121189
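    # three persistent bars: training stats, validation stats, and the learned eps values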
    tbar = tqdm(range(args.epochs), unit="epoch", position=3, ncols=0, file=sys.stdout)
    vbar = tqdm(range(args.epochs), unit="epoch", position=4, ncols=0, file=sys.stdout)
    lrbar = tqdm(range(args.epochs), unit="epoch", position=5, ncols=0, file=sys.stdout)

    for epoch, _, _ in zip(tbar, vbar, lrbar):

        train(args, model, trainloader, optimizer, criterion, epoch)
        scheduler.step()

        train_loss, train_acc = eval_net(
            args, model, trainloader, criterion)
        tbar.set_description(
            'train set - average loss: {:.4f}, accuracy: {:.0f}%'
            .format(train_loss, 100. * train_acc))

        valid_loss, valid_acc = eval_net(
            args, model, validloader, criterion)
        vbar.set_description(
            'valid set - average loss: {:.4f}, accuracy: {:.0f}%'
            .format(valid_loss, 100. * valid_acc))

        if args.filename != "":
            with open(args.filename, 'a') as f:
                f.write('%s %s %s %s %s' % (
                    args.dataset,
                    args.learn_eps,
                    args.neighbor_pooling_type,
                    args.graph_pooling_type,
                    epoch
                ))
                f.write("\n")
                f.write("%f %f %f %f" % (
                    train_loss,
                    train_acc,
                    valid_loss,
                    valid_acc
                ))
                f.write("\n")

        lrbar.set_description(
            "Learning eps with learn_eps={}: {}".format(
                args.learn_eps, [layer.eps.data.item() for layer in model.ginlayers]))

    tbar.close()
    vbar.close()
    lrbar.close()


if __name__ == '__main__':
    args = Parser(description='GIN').args
    print('show all arguments configuration...')
    print(args)

    main(args)
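
# Example invocation (a sketch only; the actual command-line flags are defined in
# ginparser.Parser and are assumed here to mirror the attribute names used above):
#
#   python main.py --dataset MUTAG --batch_size 32 --fold_idx 0 --lr 0.01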