"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "20734637f0afab9e5ad32ae04c31a4ddc99fa3ba"
Unverified commit a9768cb3 authored by Chang Liu, committed by GitHub

[Example][Refactor] Minor update on the golden example (#4197)



* minor update on golden example

* update

* update

* Update README
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
parent 6a6597a0
@@ -14,6 +14,7 @@ class SAGE(nn.Module):
     def __init__(self, in_size, hid_size, out_size):
         super().__init__()
         self.layers = nn.ModuleList()
+        # three-layer GraphSAGE-mean
         self.layers.append(dglnn.SAGEConv(in_size, hid_size, 'mean'))
         self.layers.append(dglnn.SAGEConv(hid_size, hid_size, 'mean'))
         self.layers.append(dglnn.SAGEConv(hid_size, out_size, 'mean'))
@@ -68,7 +69,7 @@ def evaluate(model, graph, dataloader):
             y_hats.append(model(blocks, x))
     return MF.accuracy(torch.cat(y_hats), torch.cat(ys))
 
-def layerwise_infer(args, device, graph, nid, model, batch_size):
+def layerwise_infer(device, graph, nid, model, batch_size):
     model.eval()
     with torch.no_grad():
         pred = model.inference(graph, device, batch_size)
@@ -80,7 +81,7 @@ def train(args, device, g, dataset, model):
     # create sampler & dataloader
     train_idx = dataset.train_idx.to(device)
     val_idx = dataset.val_idx.to(device)
-    sampler = NeighborSampler([10, 10, 10],  # fanout for layer-0, layer-1 and layer-2
+    sampler = NeighborSampler([10, 10, 10],  # fanout for [layer-0, layer-1, layer-2]
                               prefetch_node_feats=['feat'],
                               prefetch_labels=['label'])
     use_uva = (args.mode == 'mixed')
@@ -135,10 +136,10 @@ if __name__ == '__main__':
     model = SAGE(in_size, 256, out_size).to(device)
 
     # model training
-    print('Training')
+    print('Training...')
     train(args, device, g, dataset, model)
 
     # test the model
-    print('Testing')
+    print('Testing...')
-    acc = layerwise_infer(args, device, g, dataset.test_idx.to(device), model, batch_size=4096)
+    acc = layerwise_infer(device, g, dataset.test_idx.to(device), model, batch_size=4096)
     print("Test Accuracy {:.4f}".format(acc.item()))