OpenDAS / dgl · Commits

Commit da53275a (unverified)
Authored Nov 29, 2021 by Jinjing Zhou, committed by GitHub on Nov 29, 2021

Fix tgn example (#3543)

parent 03c2c6d1
Showing 1 changed file with 21 additions and 20 deletions.

examples/pytorch/tgn/dataloading.py  (+21, −20)
@@ -3,7 +3,7 @@ import dgl
 from dgl.dataloading.dataloader import EdgeCollator
 from dgl.dataloading import BlockSampler
-from dgl.dataloading.pytorch import _pop_subgraph_storage, _pop_blocks_storage
+from dgl.dataloading.pytorch import _pop_subgraph_storage, _pop_storages
 from dgl.base import DGLError
 from functools import partial
@@ -113,7 +113,7 @@ class TemporalEdgeCollator(EdgeCollator):
     eids : Tensor or dict[etype, Tensor]
         The edge set in graph :attr:`g` to compute outputs.
-    block_sampler : dgl.dataloading.BlockSampler
+    graph_sampler : dgl.dataloading.BlockSampler
         The neighborhood sampler.
     g_sampling : DGLGraph, optional
@@ -203,7 +203,7 @@ class TemporalEdgeCollator(EdgeCollator):
         for i, edge in enumerate(zip(self.g.edges()[0][items], self.g.edges()[1][items])):
             ts = pair_graph.edata['timestamp'][i]
             timestamps.append(ts)
-            subg = self.block_sampler.sample_blocks(self.g_sampling,
+            subg = self.graph_sampler.sample_blocks(self.g_sampling,
                                                     list(edge),
                                                     timestamp=ts)[0]
             subg.ndata['timestamp'] = ts.repeat(subg.num_nodes())
@@ -213,7 +213,7 @@ class TemporalEdgeCollator(EdgeCollator):
                                                         self.negative_sampler.k)
         for i, neg_edge in enumerate(zip(neg_srcdst_raw[0].tolist(), neg_srcdst_raw[1].tolist())):
             ts = timestamps[i]
-            subg = self.block_sampler.sample_blocks(self.g_sampling,
+            subg = self.graph_sampler.sample_blocks(self.g_sampling,
                                                     [neg_edge[1]],
                                                     timestamp=ts)[0]
             subg.ndata['timestamp'] = ts.repeat(subg.num_nodes())
@@ -230,7 +230,7 @@ class TemporalEdgeCollator(EdgeCollator):
         # Copy the feature from parent graph
         _pop_subgraph_storage(result[1], self.g)
         _pop_subgraph_storage(result[2], self.g)
-        _pop_blocks_storage(result[-1], self.g_sampling)
+        _pop_storages(result[-1], self.g_sampling)
         return result
@@ -248,7 +248,7 @@ class TemporalEdgeDataLoader(dgl.dataloading.EdgeDataLoader):
     eids : torch.tensor() or numpy array
         eids range which to be batched, it is useful to split training validation test dataset
-    block_sampler : dgl.dataloading.BlockSampler
+    graph_sampler : dgl.dataloading.BlockSampler
         temporal neighbor sampler which sample temporal and computationally depend blocks for computation
     device : str
@@ -264,7 +264,8 @@ class TemporalEdgeDataLoader(dgl.dataloading.EdgeDataLoader):
     """
-    def __init__(self, g, eids, block_sampler, device='cpu', collator=TemporalEdgeCollator, **kwargs):
+    def __init__(self, g, eids, graph_sampler, device='cpu', collator=TemporalEdgeCollator, **kwargs):
+        super().__init__(g, eids, graph_sampler, device, **kwargs)
         collator_kwargs = {}
         dataloader_kwargs = {}
         for k, v in kwargs.items():
@@ -272,7 +273,7 @@ class TemporalEdgeDataLoader(dgl.dataloading.EdgeDataLoader):
                 collator_kwargs[k] = v
             else:
                 dataloader_kwargs[k] = v
-        self.collator = collator(g, eids, block_sampler, **collator_kwargs)
+        self.collator = collator(g, eids, graph_sampler, **collator_kwargs)
         assert not isinstance(g, dgl.distributed.DistGraph), \
             'EdgeDataLoader does not support DistGraph for now. ' \
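After this rename, the sampler is passed to TemporalEdgeDataLoader under the parameter name graph_sampler and, via the new super().__init__ call, forwarded to dgl.dataloading.EdgeDataLoader under the same name. A minimal construction sketch under the new signature; g, train_eids, and temporal_sampler are placeholders assumed to be built elsewhere in the example, and the keyword values are illustrative only, not taken from this commit:

    # Hedged usage sketch: `temporal_sampler` stands in for whatever
    # BlockSampler subclass the TGN example constructs for temporal sampling.
    train_loader = TemporalEdgeDataLoader(
        g,                  # full temporal interaction graph
        train_eids,         # edge IDs of the training split
        temporal_sampler,   # now bound to the `graph_sampler` parameter
        device='cpu',
        batch_size=256,     # forwarded to the underlying PyTorch DataLoader
        shuffle=False,
        drop_last=False)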
@@ -485,7 +486,7 @@ class FastTemporalEdgeCollator(EdgeCollator):
     eids : Tensor or dict[etype, Tensor]
         The edge set in graph :attr:`g` to compute outputs.
-    block_sampler : dgl.dataloading.BlockSampler
+    graph_sampler : dgl.dataloading.BlockSampler
         The neighborhood sampler.
     g_sampling : DGLGraph, optional
@@ -570,7 +571,7 @@ class FastTemporalEdgeCollator(EdgeCollator):
         pair_graph.edata[dgl.EID] = induced_edges
         seed_nodes = pair_graph.ndata[dgl.NID]
-        blocks = self.block_sampler.sample_blocks(self.g_sampling, seed_nodes)
+        blocks = self.graph_sampler.sample_blocks(self.g_sampling, seed_nodes)
         blocks[0].ndata['timestamp'] = torch.zeros(blocks[0].num_nodes()).double()
         input_nodes = blocks[0].edges()[1]
@@ -578,7 +579,7 @@ class FastTemporalEdgeCollator(EdgeCollator):
         # update sampler
         _src = self.g.nodes()[self.g.edges()[0][items]]
         _dst = self.g.nodes()[self.g.edges()[1][items]]
-        self.block_sampler.add_edges(_src, _dst)
+        self.graph_sampler.add_edges(_src, _dst)
         return input_nodes, pair_graph, neg_pair_graph, blocks

     def collator(self, items):
@@ -586,7 +587,7 @@ class FastTemporalEdgeCollator(EdgeCollator):
         # Copy the feature from parent graph
         _pop_subgraph_storage(result[1], self.g)
         _pop_subgraph_storage(result[2], self.g)
-        _pop_blocks_storage(result[-1], self.g_sampling)
+        _pop_storages(result[-1], self.g_sampling)
         return result
@@ -649,7 +650,7 @@ class SimpleTemporalEdgeCollator(dgl.dataloading.EdgeCollator):
     eids : Tensor or dict[etype, Tensor]
         The edge set in graph :attr:`g` to compute outputs.
-    block_sampler : dgl.dataloading.BlockSampler
+    graph_sampler : dgl.dataloading.BlockSampler
         The neighborhood sampler.
     g_sampling : DGLGraph, optional
@@ -701,11 +702,11 @@ class SimpleTemporalEdgeCollator(dgl.dataloading.EdgeCollator):
     A set of builtin negative samplers are provided in
     :ref:`the negative sampling module <api-dataloading-negative-sampling>`.
     '''
-    def __init__(self, g, eids, block_sampler, g_sampling=None, exclude=None,
+    def __init__(self, g, eids, graph_sampler, g_sampling=None, exclude=None,
                  reverse_eids=None, reverse_etypes=None, negative_sampler=None):
-        super(SimpleTemporalEdgeCollator, self).__init__(g, eids, block_sampler, g_sampling,
-                                                         exclude, reverse_eids, reverse_etypes, negative_sampler)
-        self.n_layer = len(self.block_sampler.fanouts)
+        super(SimpleTemporalEdgeCollator, self).__init__(g, eids, graph_sampler, g_sampling,
+                                                         exclude, reverse_eids, reverse_etypes, negative_sampler)
+        self.n_layer = len(self.graph_sampler.fanouts)

     def collate(self, items):
         '''
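The docstring above points to DGL's builtin negative samplers. A brief hedged sketch of pairing the uniform sampler with this collator; g, train_eids, and temporal_sampler are placeholders assumed to exist, and 5 negatives per positive edge is an arbitrary illustrative choice:

    import dgl

    # Uniform(5): for every positive edge, draw 5 destination nodes uniformly
    # at random to serve as negative examples.
    neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
    collator = SimpleTemporalEdgeCollator(g, train_eids, temporal_sampler,
                                          negative_sampler=neg_sampler)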
@@ -713,7 +714,7 @@ class SimpleTemporalEdgeCollator(dgl.dataloading.EdgeCollator):
         We sample iteratively k-times and batch them into one single subgraph.
         '''
         current_ts = self.g.edata['timestamp'][items[0]]   #only sample edges before current timestamp
-        self.block_sampler.ts = current_ts   # restore the current timestamp to the graph sampler.
+        self.graph_sampler.ts = current_ts   # restore the current timestamp to the graph sampler.
         # if link prefiction, we use a negative_sampler to generate neg-graph for loss computing.
         if self.negative_sampler is None:
@@ -724,8 +725,8 @@ class SimpleTemporalEdgeCollator(dgl.dataloading.EdgeCollator):
         # we sampling k-hop subgraph and batch them into one graph
         for i in range(self.n_layer-1):
-            self.block_sampler.frontiers[0].add_edges(*self.block_sampler.frontiers[i+1].edges())
-        frontier = self.block_sampler.frontiers[0]
+            self.graph_sampler.frontiers[0].add_edges(*self.graph_sampler.frontiers[i+1].edges())
+        frontier = self.graph_sampler.frontiers[0]
         # computing node last-update timestamp
         frontier.update_all(fn.copy_e('timestamp', 'ts'), fn.max('ts', 'timestamp'))
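The last context line above computes each node's last-update time by message passing: fn.copy_e copies every edge's 'timestamp' feature onto a message 'ts', and fn.max keeps the maximum incoming message as the node's 'timestamp'. A small self-contained illustration of the same DGL pattern on a toy graph (graph and values invented for the example):

    import torch
    import dgl
    import dgl.function as fn

    # Toy graph with edges 0->1, 1->2, 0->2, each carrying a timestamp.
    g = dgl.graph((torch.tensor([0, 1, 0]), torch.tensor([1, 2, 2])))
    g.edata['timestamp'] = torch.tensor([1.0, 5.0, 3.0]).double()

    # Each edge sends its timestamp as message 'ts'; each destination node
    # keeps the maximum, i.e. the time of its most recent incident edge.
    g.update_all(fn.copy_e('timestamp', 'ts'), fn.max('ts', 'timestamp'))
    print(g.ndata['timestamp'])  # node 1 -> 1.0, node 2 -> 5.0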