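"""Multi-GPU graph property prediction with DGL.

Trains a GIN with virtual nodes on an OGB molecular property prediction
dataset (ogbg-molhiv or ogbg-molpcba), using one process per GPU via
torch.distributed and DistributedDataParallel:

    python multi_gpu_graph_prediction.py --dataset ogbg-molhiv
"""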
import argparse

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from ogb.graphproppred import DglGraphPropPredDataset, Evaluator
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
from tqdm import tqdm

import dgl
import dgl.nn as dglnn
from dgl.data import AsGraphPredDataset
from dgl.dataloading import GraphDataLoader


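# Two-layer MLP (Linear -> BatchNorm -> ReLU -> Linear -> BatchNorm) used as
# the learnable update function inside each GINEConv layer and for the
# virtual-node updates.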
class MLP(nn.Module):
    def __init__(self, in_feats):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_feats, 2 * in_feats),
            nn.BatchNorm1d(2 * in_feats),
            nn.ReLU(),
            nn.Linear(2 * in_feats, in_feats),
            nn.BatchNorm1d(in_feats),
        )

    def forward(self, h):
        return self.mlp(h)


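# GIN with edge features (GINEConv) plus a virtual node that exchanges
# information with every node of each graph, in the spirit of the OGB
# "GIN-virtual" baseline.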
class GIN(nn.Module):
    def __init__(self, n_hidden, n_output, n_layers=5):
        super().__init__()
        self.node_encoder = AtomEncoder(n_hidden)
        self.edge_encoders = nn.ModuleList(
            [BondEncoder(n_hidden) for _ in range(n_layers)]
        )

        self.pool = dglnn.AvgPooling()
        self.dropout = nn.Dropout(0.5)
        self.layers = nn.ModuleList()
        for _ in range(n_layers):
            self.layers.append(dglnn.GINEConv(MLP(n_hidden), learn_eps=True))
        self.predictor = nn.Linear(n_hidden, n_output)

        # Virtual node: a single zero-initialized embedding, expanded to one
        # copy per graph in forward(), that aggregates and redistributes
        # graph-level state between message-passing layers.
        self.virtual_emb = nn.Embedding(1, n_hidden)
        nn.init.constant_(self.virtual_emb.weight.data, 0)
        self.virtual_layers = nn.ModuleList()
        for _ in range(n_layers - 1):
            self.virtual_layers.append(MLP(n_hidden))
        self.virtual_pool = dglnn.SumPooling()

    def forward(self, g, x, x_e):
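        # One copy of the (shared, zero-initialized) virtual-node embedding
        # per graph in the batch.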
        v_emb = self.virtual_emb.weight.expand(g.batch_size, -1)
        hn = self.node_encoder(x)
        for i in range(len(self.layers)):
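            # Add each graph's virtual-node state to all of its nodes before
            # message passing.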
            v_hn = dgl.broadcast_nodes(g, v_emb)
            hn = hn + v_hn
            he = self.edge_encoders[i](x_e)
            hn = self.layers[i](g, hn, he)
            hn = F.relu(hn)
            hn = self.dropout(hn)
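            # Update the virtual node from the pooled node states, except
            # after the final layer (nothing would consume it).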
            if i != len(self.layers) - 1:
                v_emb_tmp = self.virtual_pool(g, hn) + v_emb
                v_emb = self.virtual_layers[i](v_emb_tmp)
                v_emb = self.dropout(F.relu(v_emb))
        hn = self.pool(g, hn)
        return self.predictor(hn)


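# Run inference over a dataloader and score the predictions with the official
# OGB evaluator (ROC-AUC for ogbg-molhiv, average precision for ogbg-molpcba).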
@torch.no_grad()
def evaluate(dataloader, device, model, evaluator):
    model.eval()
    y_true = []
    y_pred = []
    for batched_graph, labels in tqdm(dataloader):
        batched_graph, labels = batched_graph.to(device), labels.to(device)
        node_feat, edge_feat = (
            batched_graph.ndata["feat"],
            batched_graph.edata["feat"],
        )
        y_hat = model(batched_graph, node_feat, edge_feat)
        y_true.append(labels.view(y_hat.shape).detach().cpu())
        y_pred.append(y_hat.detach().cpu())
    y_true = torch.cat(y_true, dim=0).numpy()
    y_pred = torch.cat(y_pred, dim=0).numpy()
    input_dict = {"y_true": y_true, "y_pred": y_pred}
    return evaluator.eval(input_dict)


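# Per-process training entry point: `rank` identifies both the process and
# the GPU it drives; rank 0 additionally runs evaluation and logging.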
def train(rank, world_size, dataset_name, root):
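    # One process per GPU: NCCL backend with a TCP rendezvous on localhost.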
    dist.init_process_group(
        "nccl", "tcp://127.0.0.1:12347", world_size=world_size, rank=rank
    )
    torch.cuda.set_device(rank)

    dataset = AsGraphPredDataset(DglGraphPropPredDataset(dataset_name, root))
    evaluator = Evaluator(dataset_name)

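    # 300-dimensional hidden representations; DistributedDataParallel
    # synchronizes gradients across GPUs on every backward pass.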
    model = GIN(300, dataset.num_tasks).to(rank)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

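    # use_ddp=True shards the training set so each rank sees a distinct
    # subset per epoch; validation and test are evaluated on rank 0 only,
    # so those loaders stay unsharded.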
    train_dataloader = GraphDataLoader(
        dataset[dataset.train_idx], batch_size=256, use_ddp=True, shuffle=True
    )
    valid_dataloader = GraphDataLoader(dataset[dataset.val_idx], batch_size=256)
    test_dataloader = GraphDataLoader(dataset[dataset.test_idx], batch_size=256)

    for epoch in range(50):
        model.train()
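        # Reseed the distributed sampler so the shards are reshuffled
        # differently each epoch.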
        train_dataloader.set_epoch(epoch)
        for batched_graph, labels in train_dataloader:
            batched_graph, labels = batched_graph.to(rank), labels.to(rank)
            node_feat, edge_feat = (
                batched_graph.ndata["feat"],
                batched_graph.edata["feat"],
            )
            logits = model(batched_graph, node_feat, edge_feat)
            optimizer.zero_grad()
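            # OGB labels may be NaN for missing tasks; NaN != NaN, so this
            # mask keeps only the labeled entries in the loss.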
            is_labeled = labels == labels
            loss = F.binary_cross_entropy_with_logits(
                logits.float()[is_labeled], labels.float()[is_labeled]
            )
            loss.backward()
            optimizer.step()
        scheduler.step()

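        # Only rank 0 evaluates; `model.module` bypasses the DDP wrapper so
        # the other ranks need not participate in the forward pass.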
        if rank == 0:
            val_metric = evaluate(
                valid_dataloader, rank, model.module, evaluator
            )[evaluator.eval_metric]
            test_metric = evaluate(
                test_dataloader, rank, model.module, evaluator
            )[evaluator.eval_metric]

            print(
                f"Epoch: {epoch:03d}, Loss: {loss:.4f}, "
                f"Val: {val_metric:.4f}, Test: {test_metric:.4f}"
            )

    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        default="ogbg-molhiv",
        choices=["ogbg-molhiv", "ogbg-molpcba"],
        help="name of dataset (default: ogbg-molhiv)",
    )
    dataset_name = parser.parse_args().dataset
    root = "./data/OGB"
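    # Instantiate the dataset once in the parent process so download and
    # preprocessing happen before workers are spawned.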
    DglGraphPropPredDataset(dataset_name, root)

    world_size = torch.cuda.device_count()
    print("Let's use", world_size, "GPUs!")
    args = (world_size, dataset_name, root)

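    # Launch one training process per GPU; mp.spawn prepends the process
    # index (used as the rank) to `args` when calling train().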
    mp.spawn(train, args=args, nprocs=world_size, join=True)