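"""Train a GNN with GNNAutoScale (GAS) via torch_geometric_autoscale.

The graph is clustered into METIS partitions and trained in mini-batches,
with historical embeddings standing in for out-of-batch neighbors.
Hyperparameters are managed by Hydra (see conf/config.yaml).
"""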
import hydra
from tqdm import tqdm
from omegaconf import OmegaConf

import torch
from torch_geometric.nn.conv.gcn_conv import gcn_norm

from torch_geometric_autoscale import (get_data, metis, permute, models,
                                       SubgraphLoader, compute_micro_f1)

torch.manual_seed(123)
criterion = torch.nn.CrossEntropyLoss()


def train(run, model, loader, optimizer, grad_norm=None):
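    """Run one GAS training epoch over the mini-batches in `loader`."""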
    model.train()

    total_loss = total_examples = 0
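    # Each step yields a batch of one or more METIS partitions together with
    # their one-hop neighbors; the first `batch_size` nodes are in-batch.
    # The last two loader values are not needed here.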
    for batch, batch_size, n_id, _, _ in loader:
        batch = batch.to(model.device)
        n_id = n_id.to(model.device)

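        # Benchmarks with multiple official splits store their masks as
        # [num_nodes, num_splits]; select the split belonging to this run.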
        mask = batch.train_mask[:batch_size]
        mask = mask[:, run] if mask.dim() == 2 else mask
        if mask.sum() == 0:
            continue

        optimizer.zero_grad()
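        # `batch_size` and `n_id` tell the model which nodes are in-batch, so
        # it can push their fresh embeddings into its histories and pull
        # historical embeddings for the out-of-batch neighbors.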
        out = model(batch.x, batch.adj_t, batch_size, n_id)
        loss = criterion(out[mask], batch.y[:batch_size][mask])
        loss.backward()
        if grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm)
        optimizer.step()

        total_loss += float(loss) * int(mask.sum())
        total_examples += int(mask.sum())

    return total_loss / total_examples


@torch.no_grad()
def test(run, model, data):
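    """Evaluate on the full graph and return (val_acc, test_acc)."""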
    model.eval()

    val_mask = data.val_mask
    val_mask = val_mask[:, run] if val_mask.dim() == 2 else val_mask

    test_mask = data.test_mask
    test_mask = test_mask[:, run] if test_mask.dim() == 2 else test_mask

    out = model(data.x, data.adj_t)
    val_acc = compute_micro_f1(out, data.y, val_mask)
    test_acc = compute_micro_f1(out, data.y, test_mask)

    return val_acc, test_acc


@hydra.main(config_path='conf', config_name='config')
def main(conf):
    model_name, dataset_name = conf.model.name, conf.dataset.name
    conf.model.params = conf.model.params[dataset_name]
    params = conf.model.params
    print(OmegaConf.to_yaml(conf))
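    # A string value (e.g. 'None' coming from the YAML config) disables
    # gradient clipping.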
    if isinstance(params.grad_norm, str):
        params.grad_norm = None

    device = f'cuda:{conf.device}' if torch.cuda.is_available() else 'cpu'

    data, in_channels, out_channels = get_data(conf.root, dataset_name)
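    # Pre-process the adjacency matrix once: symmetric GCN normalization or
    # plain self-loops, depending on what the chosen model expects.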
    if conf.model.norm:
        data.adj_t = gcn_norm(data.adj_t)
    elif conf.model.loop:
        data.adj_t = data.adj_t.set_diag()

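    # Cluster the graph into `num_parts` partitions via METIS and permute the
    # data so that the nodes of each partition are stored consecutively.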
    perm, ptr = metis(data.adj_t, num_parts=params.num_parts, log=True)
    data = permute(data, perm, log=True)

    loader = SubgraphLoader(data, ptr, batch_size=params.batch_size,
                            shuffle=True, num_workers=params.num_workers,
                            persistent_workers=params.num_workers > 0)

    data = data.clone().to(device)  # Let's just store all data on GPU...

    GNN = getattr(models, model_name)
    model = GNN(
        num_nodes=data.num_nodes,
        in_channels=in_channels,
        out_channels=out_channels,
        device=device,  # ... and put histories on GPU as well.
        **params.architecture,
    ).to(device)

    results = torch.empty(params.runs)
    pbar = tqdm(total=params.runs * params.epochs)
    for run in range(params.runs):
        model.reset_parameters()
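        # Two parameter groups, so that regularized and non-regularized
        # modules can use different weight decay values.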
        optimizer = torch.optim.Adam([
            dict(params=model.reg_modules.parameters(),
                 weight_decay=params.reg_weight_decay),
            dict(params=model.nonreg_modules.parameters(),
                 weight_decay=params.nonreg_weight_decay)
        ], lr=params.lr)

        test(0, model, data)  # One full-graph pass to fill the history buffers.

        best_val_acc = 0
        for epoch in range(params.epochs):
            train(run, model, loader, optimizer, params.grad_norm)
            val_acc, test_acc = test(run, model, data)
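            # Model selection: keep the test accuracy from the epoch with the
            # best validation accuracy.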
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                results[run] = test_acc

            pbar.set_description(f'Mini Acc: {100 * results[run]:.2f}')
            pbar.update(1)
    pbar.close()
    print(f'Mini Acc: {100 * results.mean():.2f} ± {100 * results.std():.2f}')


if __name__ == '__main__':
    main()