Unverified Commit 929c99ed authored by Quan (Andy) Gan, committed by GitHub

[Model] Rewrite GraphSAGE example (#1938)

parent 879e4ae5
@@ -18,7 +18,6 @@ from dgl.nn.pytorch.conv import SAGEConv
 class GraphSAGE(nn.Module):
     def __init__(self,
-                 g,
                  in_feats,
                  n_hidden,
                  n_classes,
@@ -28,27 +27,31 @@ class GraphSAGE(nn.Module):
                  aggregator_type):
         super(GraphSAGE, self).__init__()
         self.layers = nn.ModuleList()
-        self.g = g
+        self.dropout = nn.Dropout(dropout)
+        self.activation = activation

         # input layer
-        self.layers.append(SAGEConv(in_feats, n_hidden, aggregator_type, feat_drop=dropout, activation=activation))
+        self.layers.append(SAGEConv(in_feats, n_hidden, aggregator_type))
         # hidden layers
         for i in range(n_layers - 1):
-            self.layers.append(SAGEConv(n_hidden, n_hidden, aggregator_type, feat_drop=dropout, activation=activation))
+            self.layers.append(SAGEConv(n_hidden, n_hidden, aggregator_type))
         # output layer
-        self.layers.append(SAGEConv(n_hidden, n_classes, aggregator_type, feat_drop=dropout, activation=None)) # activation None
+        self.layers.append(SAGEConv(n_hidden, n_classes, aggregator_type)) # activation None

-    def forward(self, features):
-        h = features
-        for layer in self.layers:
-            h = layer(self.g, h)
+    def forward(self, graph, inputs):
+        h = self.dropout(inputs)
+        for l, layer in enumerate(self.layers):
+            h = layer(graph, h)
+            if l != len(self.layers) - 1:
+                h = self.activation(h)
+                h = self.dropout(h)
         return h

-def evaluate(model, features, labels, mask):
+def evaluate(model, graph, features, labels, mask):
     model.eval()
     with torch.no_grad():
-        logits = model(features)
+        logits = model(graph, features)
         logits = logits[mask]
         labels = labels[mask]
         _, indices = torch.max(logits, dim=1)
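A minimal usage sketch of the rewritten API: the graph is now passed to each `forward` call instead of being stored on the model. The toy graph, feature sizes, and `'mean'` aggregator here are illustrative only; the constructor's positional argument order is taken from the `main()` call in the hunk below.

```python
import dgl
import torch
import torch.nn.functional as F

# Toy 4-node cycle graph with random features; all sizes are illustrative.
g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
feats = torch.randn(4, 10)
labels = torch.randint(0, 3, (4,))
mask = torch.tensor([True, True, False, False])

model = GraphSAGE(10, 16, 3,       # in_feats, n_hidden, n_classes
                  1, F.relu, 0.5,  # n_layers, activation, dropout
                  'mean')          # aggregator_type
logits = model(g, feats)           # graph passed per call, not stored on self
acc = evaluate(model, g, feats, labels, mask)
```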
@@ -101,19 +104,16 @@ def main(args):
     n_edges = g.number_of_edges()

     # create GraphSAGE model
-    model = GraphSAGE(g,
-                      in_feats,
+    model = GraphSAGE(in_feats,
                       args.n_hidden,
                       n_classes,
                       args.n_layers,
                       F.relu,
                       args.dropout,
-                      args.aggregator_type
-                      )
+                      args.aggregator_type)

     if cuda:
         model.cuda()

-    loss_fcn = torch.nn.CrossEntropyLoss()
     # use optimizer
     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
@@ -125,8 +125,8 @@
         if epoch >= 3:
             t0 = time.time()
         # forward
-        logits = model(features)
-        loss = loss_fcn(logits[train_mask], labels[train_mask])
+        logits = model(g, features)
+        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

         optimizer.zero_grad()
         loss.backward()
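The switch from the removed `loss_fcn` module to `F.cross_entropy` is behavior-preserving; a quick check (tensor shapes illustrative):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(5, 3)          # 5 samples, 3 classes
labels = torch.randint(0, 3, (5,))
assert torch.allclose(F.cross_entropy(logits, labels),
                      torch.nn.CrossEntropyLoss()(logits, labels))
```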
@@ -135,13 +135,13 @@
         if epoch >= 3:
             dur.append(time.time() - t0)

-        acc = evaluate(model, features, labels, val_mask)
+        acc = evaluate(model, g, features, labels, val_mask)
         print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
               "ETputs(KTEPS) {:.2f}".format(epoch, np.mean(dur), loss.item(),
                                             acc, n_edges / np.mean(dur) / 1000))

     print()
-    acc = evaluate(model, features, labels, test_mask)
+    acc = evaluate(model, g, features, labels, test_mask)
     print("Test Accuracy {:.4f}".format(acc))