Unverified Commit 929c99ed authored by Quan (Andy) Gan, committed by GitHub

[Model] Rewrite GraphSAGE example (#1938)

parent 879e4ae5
...
@@ -18,7 +18,6 @@ from dgl.nn.pytorch.conv import SAGEConv
 class GraphSAGE(nn.Module):
     def __init__(self,
-                 g,
                  in_feats,
                  n_hidden,
                  n_classes,
                  n_layers,
...
@@ -28,27 +27,31 @@ class GraphSAGE(nn.Module):
                  aggregator_type):
         super(GraphSAGE, self).__init__()
         self.layers = nn.ModuleList()
-        self.g = g
+        self.dropout = nn.Dropout(dropout)
+        self.activation = activation

         # input layer
-        self.layers.append(SAGEConv(in_feats, n_hidden, aggregator_type, feat_drop=dropout, activation=activation))
+        self.layers.append(SAGEConv(in_feats, n_hidden, aggregator_type))
         # hidden layers
         for i in range(n_layers - 1):
-            self.layers.append(SAGEConv(n_hidden, n_hidden, aggregator_type, feat_drop=dropout, activation=activation))
+            self.layers.append(SAGEConv(n_hidden, n_hidden, aggregator_type))
         # output layer
-        self.layers.append(SAGEConv(n_hidden, n_classes, aggregator_type, feat_drop=dropout, activation=None)) # activation None
+        self.layers.append(SAGEConv(n_hidden, n_classes, aggregator_type)) # activation None

-    def forward(self, features):
-        h = features
-        for layer in self.layers:
-            h = layer(self.g, h)
+    def forward(self, graph, inputs):
+        h = self.dropout(inputs)
+        for l, layer in enumerate(self.layers):
+            h = layer(graph, h)
+            if l != len(self.layers) - 1:
+                h = self.activation(h)
+                h = self.dropout(h)
         return h


-def evaluate(model, features, labels, mask):
+def evaluate(model, graph, features, labels, mask):
     model.eval()
     with torch.no_grad():
-        logits = model(features)
+        logits = model(graph, features)
         logits = logits[mask]
         labels = labels[mask]
         _, indices = torch.max(logits, dim=1)
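Note on the change above: the graph is no longer baked into the model at construction time but passed to forward(), so one set of trained weights can run on any graph whose node features have the same width; dropout and the activation are likewise applied between layers in forward() instead of inside each SAGEConv. A minimal sketch of the new calling convention — not part of this commit, and assuming DGL >= 0.5, where dgl.graph builds a graph from an edge list:

import torch
import torch.nn.functional as F
import dgl

# Two unrelated toy graphs; every node has at least one in-edge so the
# 'mean' aggregator is well defined everywhere.
g1 = dgl.graph(([0, 1, 2], [1, 2, 0]))        # 3-node cycle
g2 = dgl.graph(([0, 1, 1, 3], [1, 0, 3, 2]))  # 4-node graph

model = GraphSAGE(in_feats=5, n_hidden=16, n_classes=3,
                  n_layers=1, activation=F.relu,
                  dropout=0.5, aggregator_type="mean")
model.eval()  # disable dropout for a deterministic forward pass

with torch.no_grad():
    # Same weights, different graphs -- only possible now that the
    # graph is an argument of forward() rather than a constructor field.
    out1 = model(g1, torch.randn(g1.number_of_nodes(), 5))
    out2 = model(g2, torch.randn(g2.number_of_nodes(), 5))
print(out1.shape, out2.shape)  # torch.Size([3, 3]) torch.Size([4, 3])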
...
@@ -101,19 +104,16 @@ def main(args):
     n_edges = g.number_of_edges()

     # create GraphSAGE model
-    model = GraphSAGE(g,
-                      in_feats,
+    model = GraphSAGE(in_feats,
                       args.n_hidden,
                       n_classes,
                       args.n_layers,
                       F.relu,
                       args.dropout,
-                      args.aggregator_type
-                      )
+                      args.aggregator_type)

     if cuda:
         model.cuda()

-    loss_fcn = torch.nn.CrossEntropyLoss()
     # use optimizer
     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
...
@@ -125,8 +125,8 @@ def main(args):
         if epoch >= 3:
             t0 = time.time()
         # forward
-        logits = model(features)
-        loss = loss_fcn(logits[train_mask], labels[train_mask])
+        logits = model(g, features)
+        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

         optimizer.zero_grad()
         loss.backward()
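The training loop now passes the graph to the model and switches from the stored torch.nn.CrossEntropyLoss() module to the functional F.cross_entropy; the two compute the same value, which is why the loss_fcn line in main() could be dropped. An illustrative check:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)           # 4 nodes, 3 classes
labels = torch.tensor([0, 2, 1, 2])
# Functional and module forms of cross-entropy agree.
assert torch.allclose(F.cross_entropy(logits, labels),
                      torch.nn.CrossEntropyLoss()(logits, labels))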
...
@@ -135,13 +135,13 @@ def main(args):
         if epoch >= 3:
             dur.append(time.time() - t0)

-        acc = evaluate(model, features, labels, val_mask)
+        acc = evaluate(model, g, features, labels, val_mask)
         print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
               "ETputs(KTEPS) {:.2f}".format(epoch, np.mean(dur), loss.item(),
                                             acc, n_edges / np.mean(dur) / 1000))

     print()
-    acc = evaluate(model, features, labels, test_mask)
+    acc = evaluate(model, g, features, labels, test_mask)
     print("Test Accuracy {:.4f}".format(acc))
...
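The body of evaluate() is truncated in the diff, but its masked-accuracy pattern is visible: logits and labels are restricted to a boolean node mask before taking the row-wise argmax. A self-contained illustration of that pattern, with hypothetical values not taken from the commit:

import torch

logits = torch.tensor([[0.9, 0.1],
                       [0.2, 0.8],
                       [0.6, 0.4]])
labels = torch.tensor([0, 1, 1])
mask = torch.tensor([True, False, True])  # evaluate only nodes 0 and 2

_, indices = torch.max(logits[mask], dim=1)            # predictions: [0, 0]
acc = (indices == labels[mask]).float().mean().item()
print(acc)  # 0.5 -- node 0 correct, node 2 wrong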