Unverified Commit 744896e2 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

fix examples (#4016)


Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
parent bdaccc82
...@@ -127,6 +127,6 @@ model.eval() ...@@ -127,6 +127,6 @@ model.eval()
with torch.no_grad(): with torch.no_grad():
pred = model.inference(graph, device, 4096, 0, 'cpu') pred = model.inference(graph, device, 4096, 0, 'cpu')
pred = pred[test_idx].to(device) pred = pred[test_idx].to(device)
label = graph.ndata['label'][test_idx] label = graph.ndata['label'][test_idx].to(device)
acc = MF.accuracy(pred, label) acc = MF.accuracy(pred, label)
print('Test acc:', acc.item()) print('Test acc:', acc.item())
...@@ -3,7 +3,7 @@ import dgl ...@@ -3,7 +3,7 @@ import dgl
from dgl._dataloading.dataloader import EdgeCollator from dgl._dataloading.dataloader import EdgeCollator
from dgl._dataloading import BlockSampler from dgl._dataloading import BlockSampler
from dgl._dataloading.pytorch import _pop_subgraph_storage, _pop_storages from dgl._dataloading.pytorch import _pop_subgraph_storage, _pop_storages, EdgeDataLoader
from dgl.base import DGLError from dgl.base import DGLError
from functools import partial from functools import partial
...@@ -234,7 +234,7 @@ class TemporalEdgeCollator(EdgeCollator): ...@@ -234,7 +234,7 @@ class TemporalEdgeCollator(EdgeCollator):
return result return result
class TemporalEdgeDataLoader(dgl.dataloading.EdgeDataLoader): class TemporalEdgeDataLoader(EdgeDataLoader):
""" TemporalEdgeDataLoader is an iteratable object to generate blocks for temporal embedding """ TemporalEdgeDataLoader is an iteratable object to generate blocks for temporal embedding
as well as pos and neg pair graph for memory update. as well as pos and neg pair graph for memory update.
...@@ -600,7 +600,7 @@ class FastTemporalEdgeCollator(EdgeCollator): ...@@ -600,7 +600,7 @@ class FastTemporalEdgeCollator(EdgeCollator):
# "APAN: Asynchronous Propagation Attention Network for Real-time Temporal Graph Embedding" # "APAN: Asynchronous Propagation Attention Network for Real-time Temporal Graph Embedding"
# that will be appeared in SIGMOD 21, code repo https://github.com/WangXuhongCN/APAN # that will be appeared in SIGMOD 21, code repo https://github.com/WangXuhongCN/APAN
class SimpleTemporalSampler(dgl.dataloading.BlockSampler): class SimpleTemporalSampler(BlockSampler):
''' '''
Simple Temporal Sampler just choose the edges that happen before the current timestamp, to build the subgraph of the corresponding nodes. Simple Temporal Sampler just choose the edges that happen before the current timestamp, to build the subgraph of the corresponding nodes.
And then the sampler uses the simplest static graph neighborhood sampling methods. And then the sampler uses the simplest static graph neighborhood sampling methods.
...@@ -637,7 +637,7 @@ class SimpleTemporalSampler(dgl.dataloading.BlockSampler): ...@@ -637,7 +637,7 @@ class SimpleTemporalSampler(dgl.dataloading.BlockSampler):
return frontier return frontier
class SimpleTemporalEdgeCollator(dgl.dataloading.EdgeCollator): class SimpleTemporalEdgeCollator(EdgeCollator):
''' '''
Temporal Edge collator merge the edges specified by eid: items Temporal Edge collator merge the edges specified by eid: items
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment