Commit 3f464591 authored by Aymen Waheb, committed by Mufei Li

[Model] Add edge dropout to APPNP (#493)

* [Model] Add edge dropout to APPNP

* [Model] Refactor the sampling examples (#498)

* reorganize sampling code.

* speedup gcn_ns.

* speed up gcn_cv.

* fix graphsage_cv.

* undo the modification.

* accel training.

* update readme.

parent 6124667f
@@ -16,10 +16,11 @@ Contributors
 * [@hbsun2113](https://github.com/hbsun2113): GraphSAGE in Pytorch
 * [Tianyi Zhang](https://github.com/Tiiiger): SGC in Pytorch
 * [Jun Chen](https://github.com/kitaev-chen): GIN in Pytorch
+* [Aymen Waheb](https://github.com/aymenwah): APPNP in Pytorch
 
 Other improvement
 * [Brett Koonce](https://github.com/brettkoonce)
 * [@giuseppefutia](https://github.com/giuseppefutia)
 * [@mori97](https://github.com/mori97)
 * Hao Jin
 * [@aymenwah](https://github.com/aymenwah)
@@ -29,8 +29,4 @@ python train.py --dataset cora --gpu 0
 * citeseer: 0.715 (paper: 0.757)
 * pubmed: 0.793 (paper: 0.797)
 
-Differences from the original implementation
----------
-- This implementation does not perform dropout on adjacency matrices during propagation step.
-- Experiments were done on dgl datasets (GCN settings) which are different from those used in the original implementation. (discrepancies are detailed in experimental section of the original paper)
+Experiments were done on dgl datasets (GCN settings), which are different from those used in the original implementation. (Discrepancies are detailed in the experimental section of the original paper.)
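For context on what the new propagation module computes: APPNP (https://arxiv.org/abs/1810.05997) first produces predictions with an MLP and then smooths them with K steps of an approximate personalized-PageRank power iteration. A sketch of the update in the paper's notation, where \hat{A} is the normalized adjacency matrix and \alpha the teleport probability:

    Z^{(0)} = f_\theta(X), \qquad
    Z^{(k+1)} = (1 - \alpha)\,\hat{A}\,Z^{(k)} + \alpha\,Z^{(0)}, \qquad k = 0, \dots, K-1

What this commit adds is a fresh dropout mask over the entries of \hat{A} at each of the K steps, matching the adjacency dropout described in the paper.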
@@ -5,11 +5,49 @@ References
 Paper: https://arxiv.org/abs/1810.05997
 Author's code: https://github.com/klicperajo/ppnp
 """
+import torch
 import torch.nn as nn
 
 import dgl.function as fn
 
 
+class GraphPropagation(nn.Module):
+    def __init__(self,
+                 g,
+                 edge_drop,
+                 alpha,
+                 k):
+        super(GraphPropagation, self).__init__()
+        self.g = g
+        self.alpha = alpha
+        self.k = k
+        if edge_drop:
+            self.edge_drop = nn.Dropout(edge_drop)
+        else:
+            self.edge_drop = 0.
+
+    def forward(self, h):
+        self.cached_h = h
+        for _ in range(self.k):
+            # normalization by square root of src degree
+            h = h * self.g.ndata['norm']
+            self.g.ndata['h'] = h
+            if self.edge_drop:
+                # performing edge dropout; the mask is created on the same
+                # device as the features to avoid CPU/GPU mismatches
+                ed = self.edge_drop(torch.ones((self.g.number_of_edges(), 1),
+                                               device=h.device))
+                self.g.edata['d'] = ed
+                self.g.update_all(fn.src_mul_edge(src='h', edge='d', out='m'),
+                                  fn.sum(msg='m', out='h'))
+            else:
+                self.g.update_all(fn.copy_src(src='h', out='m'),
+                                  fn.sum(msg='m', out='h'))
+            h = self.g.ndata.pop('h')
+            # normalization by square root of dst degree
+            h = h * self.g.ndata['norm']
+            # update h using teleport probability alpha
+            h = h * (1 - self.alpha) + self.cached_h * self.alpha
+        return h
+
+
 class APPNP(nn.Module):
     def __init__(self,
                  g,
@@ -17,12 +55,12 @@ class APPNP(nn.Module):
                  hiddens,
                  n_classes,
                  activation,
-                 dropout,
+                 feat_drop,
+                 edge_drop,
                  alpha,
                  k):
         super(APPNP, self).__init__()
         self.layers = nn.ModuleList()
-        self.g = g
         # input layer
         self.layers.append(nn.Linear(in_feats, hiddens[0]))
         # hidden layers
@@ -31,12 +69,12 @@ class APPNP(nn.Module):
         # output layer
         self.layers.append(nn.Linear(hiddens[-1], n_classes))
         self.activation = activation
-        if dropout:
-            self.dropout = nn.Dropout(p=dropout)
+        if feat_drop:
+            self.feat_drop = nn.Dropout(feat_drop)
         else:
-            self.dropout = 0.
-        self.K = k
-        self.alpha = alpha
+            self.feat_drop = lambda x: x
+        self.propagate = GraphPropagation(g, edge_drop, alpha, k)
         self.reset_parameters()
 
     def reset_parameters(self):
         for layer in self.layers:
@@ -45,26 +83,11 @@ class APPNP(nn.Module):
     def forward(self, features):
         # prediction step
         h = features
-        if self.dropout:
-            h = self.dropout(h)
+        h = self.feat_drop(h)
         h = self.activation(self.layers[0](h))
         for layer in self.layers[1:-1]:
             h = self.activation(layer(h))
-        if self.dropout:
-            h = self.layers[-1](self.dropout(h))
-        # propagation step without dropout on adjacency matrices
-        self.cached_h = h
-        for _ in range(self.K):
-            # normalization by square root of src degree
-            h = h * self.g.ndata['norm']
-            self.g.ndata['h'] = h
-            # message-passing without performing adjacency dropout
-            self.g.update_all(fn.copy_src(src='h', out='m'),
-                              fn.sum(msg='m', out='h'))
-            h = self.g.ndata.pop('h')
-            # normalization by square root of dst degree
-            h = h * self.g.ndata['norm']
-            # update h using teleport probability alpha
-            h = h * (1 - self.alpha) + self.cached_h * self.alpha
+        h = self.layers[-1](self.feat_drop(h))
+        # propagation step
+        h = self.propagate(h)
         return h
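The edge dropout in GraphPropagation relies on a small PyTorch idiom: applying nn.Dropout to a vector of ones yields a mask whose entries are either 0 or 1/(1-p), so multiplying each edge's message by its mask entry drops whole edges while keeping the expected aggregate unchanged. A minimal self-contained sketch (all names here are illustrative, not part of the example code):

import torch
import torch.nn as nn

# nn.Dropout zeroes entries with probability p and scales survivors
# by 1 / (1 - p), so a ones-vector becomes a rescaled 0/1 edge mask.
p = 0.5
edge_drop = nn.Dropout(p)
edge_drop.train()                      # masks are only sampled in train mode
num_edges = 8
mask = edge_drop(torch.ones(num_edges, 1))
print(mask.squeeze())                  # entries are 0.0 or 1/(1-p) = 2.0
assert torch.allclose(mask[mask > 0], torch.tensor(1.0 / (1.0 - p)))

edge_drop.eval()                       # at eval time dropout is the identity
assert torch.equal(edge_drop(torch.ones(num_edges, 1)),
                   torch.ones(num_edges, 1))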
@@ -8,6 +8,7 @@ from dgl.data import register_data_args, load_data
 import dgl
 from appnp import APPNP
 
+
 def evaluate(model, features, labels, mask):
     model.eval()
     with torch.no_grad():
@@ -18,6 +19,7 @@ def evaluate(model, features, labels, mask):
         correct = torch.sum(indices == labels)
         return correct.item() * 1.0 / len(labels)
 
+
 def main(args):
     # load and preprocess dataset
     data = load_data(args)
@@ -72,13 +74,13 @@ def main(args):
                   args.hidden_sizes,
                   n_classes,
                   F.relu,
-                  args.dropout,
+                  args.in_drop,
+                  args.edge_drop,
                   args.alpha,
                   args.k)
     if cuda:
         model.cuda()
-    model.reset_parameters()
     loss_fcn = torch.nn.CrossEntropyLoss()
 
     # use optimizer
@@ -105,7 +107,7 @@ def main(args):
         acc = evaluate(model, features, labels, val_mask)
         print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
-              "ETputs(KTEPS) {:.2f}". format(epoch, np.mean(dur), loss.item(),
+              "ETputs(KTEPS) {:.2f}".format(epoch, np.mean(dur), loss.item(),
                                             acc, n_edges / np.mean(dur) / 1000))
 
     print()
@@ -116,8 +118,10 @@ def main(args):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='APPNP')
     register_data_args(parser)
-    parser.add_argument("--dropout", type=float, default=0.5,
-                        help="dropout probability")
+    parser.add_argument("--in-drop", type=float, default=0.5,
+                        help="input feature dropout")
+    parser.add_argument("--edge-drop", type=float, default=0.5,
+                        help="edge propagation dropout")
     parser.add_argument("--gpu", type=int, default=-1,
                         help="gpu")
     parser.add_argument("--lr", type=float, default=1e-2,
......
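With the new flags registered, training on Cora follows the README's example command, now with the two dropout rates exposed (the values shown are just the defaults):

python train.py --dataset cora --gpu 0 --in-drop 0.5 --edge-drop 0.5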