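"""train_ppi.py: Train a multi-head GAT with DGL's built-in GATConv module on the
PPI dataset (multi-label node classification), evaluated with micro-averaged F1."""
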
import dgl.nn as dglnn
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.data.ppi import PPIDataset
from dgl.dataloading import GraphDataLoader
from sklearn.metrics import f1_score


class GAT(nn.Module):
    def __init__(self, in_size, hid_size, out_size, heads):
        super().__init__()
        self.gat_layers = nn.ModuleList()
        # three-layer GAT
        self.gat_layers.append(
            dglnn.GATConv(in_size, hid_size, heads[0], activation=F.elu)
        )
        self.gat_layers.append(
            dglnn.GATConv(
                hid_size * heads[0],
                hid_size,
                heads[1],
                residual=True,
                activation=F.elu,
            )
        )
        self.gat_layers.append(
            dglnn.GATConv(
                hid_size * heads[1],
                out_size,
                heads[2],
                residual=True,
                activation=None,
            )
        )

    def forward(self, g, inputs):
        h = inputs
        for i, layer in enumerate(self.gat_layers):
            h = layer(g, h)
            if i == 2:  # last layer: average over the attention heads
                h = h.mean(1)
            else:  # intermediate layers: concatenate the attention heads
                h = h.flatten(1)
        return h


def evaluate(g, features, labels, model):
    model.eval()
    with torch.no_grad():
        output = model(g, features)
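        # a logit >= 0 is equivalent to sigmoid(logit) >= 0.5, i.e. the label is predicted as present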
        pred = np.where(output.data.cpu().numpy() >= 0, 1, 0)
        score = f1_score(labels.data.cpu().numpy(), pred, average="micro")
        return score


def evaluate_in_batches(dataloader, device, model):
    total_score = 0
    for batch_id, batched_graph in enumerate(dataloader):
        batched_graph = batched_graph.to(device)
        features = batched_graph.ndata["feat"]
        labels = batched_graph.ndata["label"]
        score = evaluate(batched_graph, features, labels, model)
        total_score += score
    return total_score / (batch_id + 1)  # return average score


def train(train_dataloader, val_dataloader, device, model):
    # define loss function (per-label binary cross-entropy on logits) and optimizer
    loss_fcn = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=0)

    # training loop
    for epoch in range(400):
        model.train()
        total_loss = 0
        # mini-batch loop
        for batch_id, batched_graph in enumerate(train_dataloader):
            batched_graph = batched_graph.to(device)
            features = batched_graph.ndata["feat"].float()
            labels = batched_graph.ndata["label"].float()
            logits = model(batched_graph, features)
            loss = loss_fcn(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(
            "Epoch {:05d} | Loss {:.4f} |".format(
                epoch, total_loss / (batch_id + 1)
            )
        )

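        # check micro-F1 on the validation set every 5 epochs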
        if (epoch + 1) % 5 == 0:
            avg_score = evaluate_in_batches(
                val_dataloader, device, model
            )  # evaluate F1-score instead of loss
            print(
                "                            Acc. (F1-score) {:.4f} ".format(
                    avg_score
                )
            )


if __name__ == "__main__":
    print(f"Training PPI Dataset with DGL built-in GATConv module.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load and preprocess datasets
    train_dataset = PPIDataset(mode="train")
    val_dataset = PPIDataset(mode="valid")
    test_dataset = PPIDataset(mode="test")
    features = train_dataset[0].ndata["feat"]

    # create GAT model
    in_size = features.shape[1]
    out_size = train_dataset.num_labels
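    # 3-layer GAT: hidden size 256, with 4, 4, and 6 attention heads per layer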
    model = GAT(in_size, 256, out_size, heads=[4, 4, 6]).to(device)

    # model training
    print("Training...")
    train_dataloader = GraphDataLoader(train_dataset, batch_size=2)
    val_dataloader = GraphDataLoader(val_dataset, batch_size=2)
    train(train_dataloader, val_dataloader, device, model)

    # test the model
    print("Testing...")
    test_dataloader = GraphDataLoader(test_dataset, batch_size=2)
    avg_score = evaluate_in_batches(test_dataloader, device, model)
    print("Test Accuracy (F1-score) {:.4f}".format(avg_score))