#!/usr/bin/env python
# coding: utf-8

import argparse
import urllib.request

import numpy as np
import scipy.io
import torch
import torch.nn as nn
import torch.nn.functional as F

# model.py provides the HGT and HeteroRGCN model classes used below.
from model import *

import dgl

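# Fix the PyTorch seed for reproducibility (note: the numpy-based
# train/val/test split below is not seeded, so it varies across runs).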
torch.manual_seed(0)
data_url = "https://data.dgl.ai/dataset/ACM.mat"
data_file_path = "/tmp/ACM.mat"

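# Download the ACM dataset: a .mat file of sparse adjacency matrices
# (paper-author "PvsA", paper-paper "PvsP", paper-subject "PvsL", ...).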
urllib.request.urlretrieve(data_url, data_file_path)
data = scipy.io.loadmat(data_file_path)


parser = argparse.ArgumentParser(
    description="Training GNN on the ACM benchmark"
)
parser.add_argument("--n_epoch", type=int, default=200)
parser.add_argument("--n_hid", type=int, default=256)
parser.add_argument("--n_inp", type=int, default=256)
parser.add_argument("--clip", type=int, default=1.0)
parser.add_argument("--max_lr", type=float, default=1e-3)

args = parser.parse_args()


def get_n_params(model):
    """Return the total number of parameter elements in a model."""
    return sum(p.numel() for p in model.parameters())


def train(model, G):
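    """Full-graph training loop; evaluates on the val/test splits every 5 epochs."""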
    best_val_acc = torch.tensor(0)
    best_test_acc = torch.tensor(0)
    for epoch in range(1, args.n_epoch + 1):
        model.train()
        logits = model(G, "paper")
        # The loss is computed only for labeled nodes.
        loss = F.cross_entropy(logits[train_idx], labels[train_idx].to(device))
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()
        # Step the LR schedule once per epoch (passing an explicit step count
        # to scheduler.step() is deprecated in recent PyTorch).
        scheduler.step()
        if epoch % 5 == 0:
            model.eval()
            # Inference only: no gradients needed for evaluation.
            with torch.no_grad():
                logits = model(G, "paper")
                pred = logits.argmax(1).cpu()
                train_acc = (pred[train_idx] == labels[train_idx]).float().mean()
                val_acc = (pred[val_idx] == labels[val_idx]).float().mean()
                test_acc = (pred[test_idx] == labels[test_idx]).float().mean()
            if best_val_acc < val_acc:
                best_val_acc = val_acc
                best_test_acc = test_acc
            print(
                "Epoch: %d LR: %.5f Loss %.4f, Train Acc %.4f, Val Acc %.4f (Best %.4f), Test Acc %.4f (Best %.4f)"
                % (
                    epoch,
                    optimizer.param_groups[0]["lr"],
                    loss.item(),
                    train_acc.item(),
                    val_acc.item(),
                    best_val_acc.item(),
                    test_acc.item(),
                    best_test_acc.item(),
                )
            )


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

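# Build a heterogeneous graph: each relation is given by the (row, col)
# index arrays that .nonzero() returns for the corresponding sparse matrix.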
G = dgl.heterograph(
    {
        ("paper", "written-by", "author"): data["PvsA"].nonzero(),
        ("author", "writing", "paper"): data["PvsA"].transpose().nonzero(),
        ("paper", "citing", "paper"): data["PvsP"].nonzero(),
        ("paper", "cited", "paper"): data["PvsP"].transpose().nonzero(),
        ("paper", "is-about", "subject"): data["PvsL"].nonzero(),
        ("subject", "has", "paper"): data["PvsL"].transpose().nonzero(),
    }
)
print(G)

pvc = data["PvsC"].tocsr()
p_selected = pvc.tocoo()
# Generate labels: in the CSR paper-vs-conference matrix, the column index
# of each stored entry is the conference id of the corresponding paper.
labels = pvc.indices
labels = torch.tensor(labels).long()

# generate train/val/test split
pid = p_selected.row
shuffle = np.random.permutation(pid)
train_idx = torch.tensor(shuffle[0:800]).long()
val_idx = torch.tensor(shuffle[800:900]).long()
test_idx = torch.tensor(shuffle[900:]).long()

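# Assign an integer id to every node type and edge type; type-aware models
# such as HGT use these ids to index their per-type parameters.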
node_dict = {}
edge_dict = {}
for ntype in G.ntypes:
    node_dict[ntype] = len(node_dict)
for etype in G.etypes:
    edge_dict[etype] = len(edge_dict)
    # Tag every edge with its edge-type id.
    G.edges[etype].data["id"] = (
        torch.ones(G.num_edges(etype), dtype=torch.long) * edge_dict[etype]
    )

# Randomly initialize (frozen) input features for every node type;
# the feature dimension must match the model's n_inp.
for ntype in G.ntypes:
    emb = nn.Parameter(
        torch.Tensor(G.num_nodes(ntype), args.n_inp), requires_grad=False
    )
    nn.init.xavier_uniform_(emb)
    G.nodes[ntype].data["inp"] = emb

G = G.to(device)

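# Model 1: a 2-layer, 4-head Heterogeneous Graph Transformer (HGT).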
model = HGT(
    G,
    node_dict,
    edge_dict,
    n_inp=args.n_inp,
    n_hid=args.n_hid,
    n_out=labels.max().item() + 1,
    n_layers=2,
    n_heads=4,
    use_norm=True,
).to(device)
optimizer = torch.optim.AdamW(model.parameters())
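# One-cycle LR schedule: ramp the learning rate up to max_lr, then anneal it,
# over n_epoch steps (train() calls scheduler.step() once per epoch).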
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, total_steps=args.n_epoch, max_lr=args.max_lr
)
print("Training HGT with #param: %d" % (get_n_params(model)))
train(model, G)


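# Model 2 (baseline): a heterogeneous relational GCN on the same graph.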
model = HeteroRGCN(
    G,
    in_size=args.n_inp,
    hidden_size=args.n_hid,
    out_size=labels.max().item() + 1,
).to(device)
optimizer = torch.optim.AdamW(model.parameters())
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, total_steps=args.n_epoch, max_lr=args.max_lr
)
print("Training RGCN with #param: %d" % (get_n_params(model)))
train(model, G)


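# Model 3 (baseline): HGT with n_layers=0 performs no message passing,
# so it reduces to a per-node MLP, matching the "MLP" label printed below.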
model = HGT(
    G,
    node_dict,
    edge_dict,
    n_inp=args.n_inp,
    n_hid=args.n_hid,
    n_out=labels.max().item() + 1,
    n_layers=0,
    n_heads=4,
).to(device)
optimizer = torch.optim.AdamW(model.parameters())
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, total_steps=args.n_epoch, max_lr=args.max_lr
)
print("Training MLP with #param: %d" % (get_n_params(model)))
train(model, G)