Unverified commit a9768cb3 authored by Chang Liu, committed by GitHub

[Example][Refactor] Minor update on the golden example (#4197)



* minor update on golden example

* update

* update

* Update README
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
parent 6a6597a0
@@ -14,6 +14,7 @@ class SAGE(nn.Module):
     def __init__(self, in_size, hid_size, out_size):
         super().__init__()
         self.layers = nn.ModuleList()
+        # three-layer GraphSAGE-mean
         self.layers.append(dglnn.SAGEConv(in_size, hid_size, 'mean'))
         self.layers.append(dglnn.SAGEConv(hid_size, hid_size, 'mean'))
         self.layers.append(dglnn.SAGEConv(hid_size, out_size, 'mean'))
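For readers of this example, here is a minimal sketch of how a stacked `dglnn.SAGEConv` model like the one above is typically driven over the sampled message-flow graphs (blocks). The forward pass, dropout rate, and activation placement below are illustrative assumptions, not lines taken from the example itself.

```python
import torch.nn as nn
import torch.nn.functional as F
import dgl.nn as dglnn

class SAGE(nn.Module):
    """Three-layer GraphSAGE-mean, as in the example; forward pass is a sketch."""
    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_size, hid_size, 'mean'))
        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, 'mean'))
        self.layers.append(dglnn.SAGEConv(hid_size, out_size, 'mean'))
        self.dropout = nn.Dropout(0.5)  # assumed dropout rate

    def forward(self, blocks, x):
        # one block (message-flow graph) per layer, outermost layer first
        h = x
        for l, layer in enumerate(self.layers):
            h = layer(blocks[l], h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h
```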
@@ -68,7 +69,7 @@ def evaluate(model, graph, dataloader):
             y_hats.append(model(blocks, x))
     return MF.accuracy(torch.cat(y_hats), torch.cat(ys))
 
-def layerwise_infer(args, device, graph, nid, model, batch_size):
+def layerwise_infer(device, graph, nid, model, batch_size):
     model.eval()
     with torch.no_grad():
         pred = model.inference(graph, device, batch_size)
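The refactor drops the unused `args` parameter from `layerwise_infer`. The hunk cuts off after `model.inference`; a plausible completion is sketched below under assumptions: the slicing by `nid`, the `'label'` lookup, and the device handling are not taken from the example, which defines `model.inference` elsewhere and already imports torchmetrics.functional as `MF`.

```python
import torch
import torchmetrics.functional as MF

def layerwise_infer(device, graph, nid, model, batch_size):
    # Assumes model.inference (defined in the example) returns predictions for all
    # nodes, buffered on the same device as `graph` (typically CPU for a large graph).
    model.eval()
    with torch.no_grad():
        pred = model.inference(graph, device, batch_size)
        nid = nid.to(pred.device)            # test indices may arrive on the GPU
        label = graph.ndata['label'][nid]
        return MF.accuracy(pred[nid], label.to(pred.device))
```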
@@ -80,7 +81,7 @@ def train(args, device, g, dataset, model):
     # create sampler & dataloader
     train_idx = dataset.train_idx.to(device)
     val_idx = dataset.val_idx.to(device)
-    sampler = NeighborSampler([10, 10, 10],  # fanout for layer-0, layer-1 and layer-2
+    sampler = NeighborSampler([10, 10, 10],  # fanout for [layer-0, layer-1, layer-2]
                               prefetch_node_feats=['feat'],
                               prefetch_labels=['label'])
     use_uva = (args.mode == 'mixed')
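A minimal sketch of how this `NeighborSampler` is typically wired into a `dgl.dataloading.DataLoader`; the batch size, `shuffle`, and `num_workers` values are assumptions, while `use_uva` mirrors the `args.mode == 'mixed'` flag above.

```python
from dgl.dataloading import DataLoader, NeighborSampler

def make_train_dataloader(g, train_idx, device, use_uva):
    # sample 10 neighbors per node at each of the three layers
    sampler = NeighborSampler([10, 10, 10],
                              prefetch_node_feats=['feat'],
                              prefetch_labels=['label'])
    return DataLoader(g, train_idx, sampler,
                      device=device,
                      batch_size=1024,   # assumed batch size
                      shuffle=True,
                      drop_last=False,
                      num_workers=0,
                      use_uva=use_uva)   # sample from pinned host memory in 'mixed' mode
```

With UVA enabled, the graph stays in pinned host memory and the GPU samples from it directly, which is what the 'mixed' mode above is for.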
@@ -135,10 +136,10 @@ if __name__ == '__main__':
     model = SAGE(in_size, 256, out_size).to(device)
 
     # model training
-    print('Training')
+    print('Training...')
     train(args, device, g, dataset, model)
 
     # test the model
-    print('Testing')
-    acc = layerwise_infer(args, device, g, dataset.test_idx.to(device), model, batch_size=4096)
+    print('Testing...')
+    acc = layerwise_infer(device, g, dataset.test_idx.to(device), model, batch_size=4096)
     print("Test Accuracy {:.4f}".format(acc.item()))