Commit ad9da36a authored by Zhengwei, committed by Zihao Ye

[Model] Add DGI Model (#501)

parent fe7d5e9b
Deep Graph Infomax (DGI)
========================
- Paper link: [https://arxiv.org/abs/1809.10341](https://arxiv.org/abs/1809.10341)
- Author's code repo (in PyTorch): [https://github.com/PetarV-/DGI](https://github.com/PetarV-/DGI)
Dependencies
------------
- PyTorch 0.4.1+
- requests
```bash
pip install torch requests
```
How to run
----------
Run with the following:
```bash
python train.py --dataset=cora --gpu=0 --self-loop
python train.py --dataset=citeseer --gpu=0
python train.py --dataset=pubmed --gpu=0
```
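All hyper-parameters exposed by `train.py` can be overridden on the command line; for example (flags taken from the argparse definitions in `train.py`, with `--gpu=-1` meaning CPU):
```bash
python train.py --dataset=cora --gpu=-1 --self-loop --n-hidden=512 --n-dgi-epochs=300
```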
Results
-------
* cora: ~81.6 (81.2-82.1) (paper: 82.3)
* citeseer: ~69.4 (paper: 71.8)
* pubmed: ~76.1 (paper: 76.8)
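
DGI trains an encoder so that a discriminator can tell embeddings of the real graph from embeddings of a corrupted (feature-shuffled) copy, each scored against a graph-level summary vector. A minimal, self-contained sketch of the objective (a `Linear` stand-in replaces the GCN encoder; sizes are illustrative):
```python
import torch
import torch.nn as nn

num_nodes, in_feats, n_hidden = 8, 16, 32
features = torch.randn(num_nodes, in_feats)
encoder = nn.Linear(in_feats, n_hidden)        # stand-in for the GCN encoder
weight = nn.Parameter(torch.randn(n_hidden, n_hidden))
bce = nn.BCEWithLogitsLoss()

positive = encoder(features)                             # real embeddings
negative = encoder(features[torch.randperm(num_nodes)])  # corrupted: shuffled rows
summary = torch.sigmoid(positive.mean(dim=0))            # graph-level readout
score = lambda h: h @ (weight @ summary)                 # bilinear discriminator
loss = bce(score(positive), torch.ones(num_nodes)) + \
       bce(score(negative), torch.zeros(num_nodes))
```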
"""
Deep Graph Infomax in DGL
References
----------
Paper: https://arxiv.org/abs/1809.10341
Author's code: https://github.com/PetarV-/DGI
"""
import torch
import torch.nn as nn
import math
from gcn import GCN
class Encoder(nn.Module):
    def __init__(self, g, in_feats, n_hidden, n_layers, activation, dropout):
        super(Encoder, self).__init__()
        self.g = g
        self.conv = GCN(g, in_feats, n_hidden, n_hidden, n_layers, activation, dropout)

    def forward(self, features, corrupt=False):
        if corrupt:
            # Corruption function: shuffle node features across nodes while
            # keeping the graph structure fixed, producing negative samples.
            perm = torch.randperm(self.g.number_of_nodes())
            features = features[perm]
        features = self.conv(features)
        return features
class Discriminator(nn.Module):
    def __init__(self, n_hidden):
        super(Discriminator, self).__init__()
        self.weight = nn.Parameter(torch.Tensor(n_hidden, n_hidden))
        self.reset_parameters()

    def uniform(self, size, tensor):
        bound = 1.0 / math.sqrt(size)
        if tensor is not None:
            tensor.data.uniform_(-bound, bound)

    def reset_parameters(self):
        size = self.weight.size(0)
        self.uniform(size, self.weight)

    def forward(self, features, summary):
        # Bilinear score h^T W s, returned as unnormalized logits.
        features = torch.matmul(features, torch.matmul(self.weight, summary))
        return features
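
# A quick shape check for the bilinear scorer above (illustrative values,
# not part of the model): with features of shape (num_nodes, n_hidden) and
# summary of shape (n_hidden,), the score h @ (W @ summary) has shape
# (num_nodes,), i.e. one unnormalized logit per node.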
class DGI(nn.Module):
    def __init__(self, g, in_feats, n_hidden, n_layers, activation, dropout):
        super(DGI, self).__init__()
        self.encoder = Encoder(g, in_feats, n_hidden, n_layers, activation, dropout)
        self.discriminator = Discriminator(n_hidden)
        self.loss = nn.BCEWithLogitsLoss()

    def forward(self, features):
        positive = self.encoder(features, corrupt=False)
        negative = self.encoder(features, corrupt=True)
        # Graph-level summary vector: sigmoid of the mean real embedding.
        summary = torch.sigmoid(positive.mean(dim=0))

        positive = self.discriminator(positive, summary)
        negative = self.discriminator(negative, summary)

        l1 = self.loss(positive, torch.ones_like(positive))
        l2 = self.loss(negative, torch.zeros_like(negative))

        return l1 + l2
class Classifier(nn.Module):
    def __init__(self, n_hidden, n_classes):
        super(Classifier, self).__init__()
        self.fc = nn.Linear(n_hidden, n_classes)
        self.reset_parameters()

    def reset_parameters(self):
        self.fc.reset_parameters()

    def forward(self, features):
        features = self.fc(features)
        return torch.log_softmax(features, dim=-1)
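
# Illustrative end-to-end usage (assumed variable names; train.py below
# shows the real pipeline):
#   dgi = DGI(g, in_feats, n_hidden=512, n_layers=1,
#             activation=nn.PReLU(512), dropout=0.)
#   loss = dgi(features)                     # unsupervised InfoMax loss
#   embeds = dgi.encoder(features).detach()  # frozen node embeddings
#   clf = Classifier(n_hidden=512, n_classes=n_classes)
#   log_probs = clf(embeds)                  # per-class log-probabilities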
"""
This code was copied from the GCN implementation in DGL examples.
"""
import torch
import torch.nn as nn
from dgl.nn.pytorch import GraphConv
class GCN(nn.Module):
    def __init__(self,
                 g,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout):
        super(GCN, self).__init__()
        self.g = g
        self.layers = nn.ModuleList()
        # input layer
        self.layers.append(GraphConv(in_feats, n_hidden, activation=activation))
        # hidden layers
        for i in range(n_layers - 1):
            self.layers.append(GraphConv(n_hidden, n_hidden, activation=activation))
        # output layer
        self.layers.append(GraphConv(n_hidden, n_classes))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, features):
        h = features
        for i, layer in enumerate(self.layers):
            if i != 0:
                h = self.dropout(h)
            h = layer(h, self.g)
        return h
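
# Note: a GCN built with n_layers=k stacks k+1 GraphConv modules in total
# (one input layer, k-1 hidden layers, one output layer). The Encoder in
# dgi.py passes n_hidden for n_classes, so its output dimension is n_hidden.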
import argparse
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
from dgi import DGI, Classifier
def evaluate(model, features, labels, mask):
    model.eval()
    with torch.no_grad():
        logits = model(features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)
def main(args):
    # load and preprocess dataset
    data = load_data(args)
    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    train_mask = torch.ByteTensor(data.train_mask)
    val_mask = torch.ByteTensor(data.val_mask)
    test_mask = torch.ByteTensor(data.test_mask)
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()

    if args.gpu < 0:
        cuda = False
    else:
        cuda = True
        torch.cuda.set_device(args.gpu)
        features = features.cuda()
        labels = labels.cuda()
        train_mask = train_mask.cuda()
        val_mask = val_mask.cuda()
        test_mask = test_mask.cuda()

    # graph preprocess
    g = data.graph
    # add self loop
    if args.self_loop:
        g.remove_edges_from(g.selfloop_edges())
        g.add_edges_from(zip(g.nodes(), g.nodes()))
    g = DGLGraph(g)
    n_edges = g.number_of_edges()
    # create DGI model
    dgi = DGI(g,
              in_feats,
              args.n_hidden,
              args.n_layers,
              nn.PReLU(args.n_hidden),
              args.dropout)
    if cuda:
        dgi.cuda()

    dgi_optimizer = torch.optim.Adam(dgi.parameters(),
                                     lr=args.dgi_lr,
                                     weight_decay=args.weight_decay)
    # train deep graph infomax
    cnt_wait = 0
    best = 1e9
    best_t = 0
    dur = []
    for epoch in range(args.n_dgi_epochs):
        dgi.train()
        if epoch >= 3:
            t0 = time.time()

        dgi_optimizer.zero_grad()
        loss = dgi(features)
        loss.backward()
        dgi_optimizer.step()

        # early stopping on the unsupervised training loss
        if loss < best:
            best = loss
            best_t = epoch
            cnt_wait = 0
            torch.save(dgi.state_dict(), 'best_dgi.pkl')
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            print('Early stopping!')
            break

        if epoch >= 3:
            dur.append(time.time() - t0)

        print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | "
              "ETputs(KTEPS) {:.2f}".format(epoch, np.mean(dur), loss.item(),
                                            n_edges / np.mean(dur) / 1000))
    # create classifier model
    classifier = Classifier(args.n_hidden, n_classes)
    if cuda:
        classifier.cuda()

    classifier_optimizer = torch.optim.Adam(classifier.parameters(),
                                            lr=args.classifier_lr,
                                            weight_decay=args.weight_decay)

    # train classifier on the frozen DGI embeddings
    print('Loading {}th epoch'.format(best_t))
    dgi.load_state_dict(torch.load('best_dgi.pkl'))
    embeds = dgi.encoder(features, corrupt=False)
    embeds = embeds.detach()
    dur = []
    for epoch in range(args.n_classifier_epochs):
        classifier.train()
        if epoch >= 3:
            t0 = time.time()

        classifier_optimizer.zero_grad()
        preds = classifier(embeds)
        loss = F.nll_loss(preds[train_mask], labels[train_mask])
        loss.backward()
        classifier_optimizer.step()

        if epoch >= 3:
            dur.append(time.time() - t0)

        acc = evaluate(classifier, embeds, labels, val_mask)
        print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
              "ETputs(KTEPS) {:.2f}".format(epoch, np.mean(dur), loss.item(),
                                            acc, n_edges / np.mean(dur) / 1000))

    print()
    acc = evaluate(classifier, embeds, labels, test_mask)
    print("Test Accuracy {:.4f}".format(acc))
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='DGI')
    register_data_args(parser)
    parser.add_argument("--dropout", type=float, default=0.,
                        help="dropout probability")
    parser.add_argument("--gpu", type=int, default=-1,
                        help="gpu id; -1 for cpu")
    parser.add_argument("--dgi-lr", type=float, default=1e-3,
                        help="dgi learning rate")
    parser.add_argument("--classifier-lr", type=float, default=1e-2,
                        help="classifier learning rate")
    parser.add_argument("--n-dgi-epochs", type=int, default=300,
                        help="number of DGI training epochs")
    parser.add_argument("--n-classifier-epochs", type=int, default=300,
                        help="number of classifier training epochs")
    parser.add_argument("--n-hidden", type=int, default=512,
                        help="number of hidden gcn units")
    parser.add_argument("--n-layers", type=int, default=1,
                        help="number of hidden gcn layers")
    parser.add_argument("--weight-decay", type=float, default=0.,
                        help="weight for L2 loss")
    parser.add_argument("--patience", type=int, default=20,
                        help="early-stopping patience")
    parser.add_argument("--self-loop", action='store_true',
                        help="add graph self-loops (default=False)")
    parser.set_defaults(self_loop=False)
    args = parser.parse_args()
    print(args)

    main(args)
@@ -383,7 +383,7 @@ def _normalize(mx):
     """Row-normalize sparse matrix"""
     rowsum = np.array(mx.sum(1))
     r_inv = np.power(rowsum, -1).flatten()
-    r_inv[np.isinf(r_inv)] = np.inf
+    r_inv[np.isinf(r_inv)] = 0.
     r_mat_inv = sp.diags(r_inv)
     mx = r_mat_inv.dot(mx)
     return mx
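
The one-line change above fixes normalization of zero-degree rows: `np.power(0., -1)` yields `inf`, and keeping `np.inf` in `r_inv` propagates `inf`/`nan` through the normalized matrix, whereas `0.` leaves those rows as all zeros. A standalone illustration with toy values:
```python
import numpy as np

rowsum = np.array([[2.], [0.]])          # second row sums to zero
r_inv = np.power(rowsum, -1).flatten()   # -> [0.5, inf] (with a divide warning)
r_inv[np.isinf(r_inv)] = 0.              # zero-degree rows now normalize to zero
```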