Unverified Commit da53275a authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

Fix tgn example (#3543)

parent 03c2c6d1
...@@ -3,7 +3,7 @@ import dgl ...@@ -3,7 +3,7 @@ import dgl
from dgl.dataloading.dataloader import EdgeCollator from dgl.dataloading.dataloader import EdgeCollator
from dgl.dataloading import BlockSampler from dgl.dataloading import BlockSampler
from dgl.dataloading.pytorch import _pop_subgraph_storage, _pop_blocks_storage from dgl.dataloading.pytorch import _pop_subgraph_storage, _pop_storages
from dgl.base import DGLError from dgl.base import DGLError
from functools import partial from functools import partial
...@@ -113,7 +113,7 @@ class TemporalEdgeCollator(EdgeCollator): ...@@ -113,7 +113,7 @@ class TemporalEdgeCollator(EdgeCollator):
eids : Tensor or dict[etype, Tensor] eids : Tensor or dict[etype, Tensor]
The edge set in graph :attr:`g` to compute outputs. The edge set in graph :attr:`g` to compute outputs.
block_sampler : dgl.dataloading.BlockSampler graph_sampler : dgl.dataloading.BlockSampler
The neighborhood sampler. The neighborhood sampler.
g_sampling : DGLGraph, optional g_sampling : DGLGraph, optional
...@@ -203,7 +203,7 @@ class TemporalEdgeCollator(EdgeCollator): ...@@ -203,7 +203,7 @@ class TemporalEdgeCollator(EdgeCollator):
for i, edge in enumerate(zip(self.g.edges()[0][items], self.g.edges()[1][items])): for i, edge in enumerate(zip(self.g.edges()[0][items], self.g.edges()[1][items])):
ts = pair_graph.edata['timestamp'][i] ts = pair_graph.edata['timestamp'][i]
timestamps.append(ts) timestamps.append(ts)
subg = self.block_sampler.sample_blocks(self.g_sampling, subg = self.graph_sampler.sample_blocks(self.g_sampling,
list(edge), list(edge),
timestamp=ts)[0] timestamp=ts)[0]
subg.ndata['timestamp'] = ts.repeat(subg.num_nodes()) subg.ndata['timestamp'] = ts.repeat(subg.num_nodes())
...@@ -213,7 +213,7 @@ class TemporalEdgeCollator(EdgeCollator): ...@@ -213,7 +213,7 @@ class TemporalEdgeCollator(EdgeCollator):
self.negative_sampler.k) self.negative_sampler.k)
for i, neg_edge in enumerate(zip(neg_srcdst_raw[0].tolist(), neg_srcdst_raw[1].tolist())): for i, neg_edge in enumerate(zip(neg_srcdst_raw[0].tolist(), neg_srcdst_raw[1].tolist())):
ts = timestamps[i] ts = timestamps[i]
subg = self.block_sampler.sample_blocks(self.g_sampling, subg = self.graph_sampler.sample_blocks(self.g_sampling,
[neg_edge[1]], [neg_edge[1]],
timestamp=ts)[0] timestamp=ts)[0]
subg.ndata['timestamp'] = ts.repeat(subg.num_nodes()) subg.ndata['timestamp'] = ts.repeat(subg.num_nodes())
...@@ -230,7 +230,7 @@ class TemporalEdgeCollator(EdgeCollator): ...@@ -230,7 +230,7 @@ class TemporalEdgeCollator(EdgeCollator):
# Copy the feature from parent graph # Copy the feature from parent graph
_pop_subgraph_storage(result[1], self.g) _pop_subgraph_storage(result[1], self.g)
_pop_subgraph_storage(result[2], self.g) _pop_subgraph_storage(result[2], self.g)
_pop_blocks_storage(result[-1], self.g_sampling) _pop_storages(result[-1], self.g_sampling)
return result return result
...@@ -248,7 +248,7 @@ class TemporalEdgeDataLoader(dgl.dataloading.EdgeDataLoader): ...@@ -248,7 +248,7 @@ class TemporalEdgeDataLoader(dgl.dataloading.EdgeDataLoader):
eids : torch.tensor() or numpy array eids : torch.tensor() or numpy array
eids range which to be batched, it is useful to split training validation test dataset eids range which to be batched, it is useful to split training validation test dataset
block_sampler : dgl.dataloading.BlockSampler graph_sampler : dgl.dataloading.BlockSampler
temporal neighbor sampler which sample temporal and computationally depend blocks for computation temporal neighbor sampler which sample temporal and computationally depend blocks for computation
device : str device : str
...@@ -264,7 +264,8 @@ class TemporalEdgeDataLoader(dgl.dataloading.EdgeDataLoader): ...@@ -264,7 +264,8 @@ class TemporalEdgeDataLoader(dgl.dataloading.EdgeDataLoader):
""" """
def __init__(self, g, eids, block_sampler, device='cpu', collator=TemporalEdgeCollator, **kwargs): def __init__(self, g, eids, graph_sampler, device='cpu', collator=TemporalEdgeCollator, **kwargs):
super().__init__(g, eids, graph_sampler, device, **kwargs)
collator_kwargs = {} collator_kwargs = {}
dataloader_kwargs = {} dataloader_kwargs = {}
for k, v in kwargs.items(): for k, v in kwargs.items():
...@@ -272,7 +273,7 @@ class TemporalEdgeDataLoader(dgl.dataloading.EdgeDataLoader): ...@@ -272,7 +273,7 @@ class TemporalEdgeDataLoader(dgl.dataloading.EdgeDataLoader):
collator_kwargs[k] = v collator_kwargs[k] = v
else: else:
dataloader_kwargs[k] = v dataloader_kwargs[k] = v
self.collator = collator(g, eids, block_sampler, **collator_kwargs) self.collator = collator(g, eids, graph_sampler, **collator_kwargs)
assert not isinstance(g, dgl.distributed.DistGraph), \ assert not isinstance(g, dgl.distributed.DistGraph), \
'EdgeDataLoader does not support DistGraph for now. ' \ 'EdgeDataLoader does not support DistGraph for now. ' \
...@@ -485,7 +486,7 @@ class FastTemporalEdgeCollator(EdgeCollator): ...@@ -485,7 +486,7 @@ class FastTemporalEdgeCollator(EdgeCollator):
eids : Tensor or dict[etype, Tensor] eids : Tensor or dict[etype, Tensor]
The edge set in graph :attr:`g` to compute outputs. The edge set in graph :attr:`g` to compute outputs.
block_sampler : dgl.dataloading.BlockSampler graph_sampler : dgl.dataloading.BlockSampler
The neighborhood sampler. The neighborhood sampler.
g_sampling : DGLGraph, optional g_sampling : DGLGraph, optional
...@@ -570,7 +571,7 @@ class FastTemporalEdgeCollator(EdgeCollator): ...@@ -570,7 +571,7 @@ class FastTemporalEdgeCollator(EdgeCollator):
pair_graph.edata[dgl.EID] = induced_edges pair_graph.edata[dgl.EID] = induced_edges
seed_nodes = pair_graph.ndata[dgl.NID] seed_nodes = pair_graph.ndata[dgl.NID]
blocks = self.block_sampler.sample_blocks(self.g_sampling, seed_nodes) blocks = self.graph_sampler.sample_blocks(self.g_sampling, seed_nodes)
blocks[0].ndata['timestamp'] = torch.zeros( blocks[0].ndata['timestamp'] = torch.zeros(
blocks[0].num_nodes()).double() blocks[0].num_nodes()).double()
input_nodes = blocks[0].edges()[1] input_nodes = blocks[0].edges()[1]
...@@ -578,7 +579,7 @@ class FastTemporalEdgeCollator(EdgeCollator): ...@@ -578,7 +579,7 @@ class FastTemporalEdgeCollator(EdgeCollator):
# update sampler # update sampler
_src = self.g.nodes()[self.g.edges()[0][items]] _src = self.g.nodes()[self.g.edges()[0][items]]
_dst = self.g.nodes()[self.g.edges()[1][items]] _dst = self.g.nodes()[self.g.edges()[1][items]]
self.block_sampler.add_edges(_src, _dst) self.graph_sampler.add_edges(_src, _dst)
return input_nodes, pair_graph, neg_pair_graph, blocks return input_nodes, pair_graph, neg_pair_graph, blocks
def collator(self, items): def collator(self, items):
...@@ -586,7 +587,7 @@ class FastTemporalEdgeCollator(EdgeCollator): ...@@ -586,7 +587,7 @@ class FastTemporalEdgeCollator(EdgeCollator):
# Copy the feature from parent graph # Copy the feature from parent graph
_pop_subgraph_storage(result[1], self.g) _pop_subgraph_storage(result[1], self.g)
_pop_subgraph_storage(result[2], self.g) _pop_subgraph_storage(result[2], self.g)
_pop_blocks_storage(result[-1], self.g_sampling) _pop_storages(result[-1], self.g_sampling)
return result return result
...@@ -649,7 +650,7 @@ class SimpleTemporalEdgeCollator(dgl.dataloading.EdgeCollator): ...@@ -649,7 +650,7 @@ class SimpleTemporalEdgeCollator(dgl.dataloading.EdgeCollator):
eids : Tensor or dict[etype, Tensor] eids : Tensor or dict[etype, Tensor]
The edge set in graph :attr:`g` to compute outputs. The edge set in graph :attr:`g` to compute outputs.
block_sampler : dgl.dataloading.BlockSampler graph_sampler : dgl.dataloading.BlockSampler
The neighborhood sampler. The neighborhood sampler.
g_sampling : DGLGraph, optional g_sampling : DGLGraph, optional
...@@ -701,11 +702,11 @@ class SimpleTemporalEdgeCollator(dgl.dataloading.EdgeCollator): ...@@ -701,11 +702,11 @@ class SimpleTemporalEdgeCollator(dgl.dataloading.EdgeCollator):
A set of builtin negative samplers are provided in A set of builtin negative samplers are provided in
:ref:`the negative sampling module <api-dataloading-negative-sampling>`. :ref:`the negative sampling module <api-dataloading-negative-sampling>`.
''' '''
def __init__(self, g, eids, block_sampler, g_sampling=None, exclude=None, def __init__(self, g, eids, graph_sampler, g_sampling=None, exclude=None,
reverse_eids=None, reverse_etypes=None, negative_sampler=None): reverse_eids=None, reverse_etypes=None, negative_sampler=None):
super(SimpleTemporalEdgeCollator,self).__init__(g,eids,block_sampler, super(SimpleTemporalEdgeCollator, self).__init__(g, eids, graph_sampler,
g_sampling,exclude,reverse_eids,reverse_etypes,negative_sampler) g_sampling, exclude, reverse_eids, reverse_etypes, negative_sampler)
self.n_layer = len(self.block_sampler.fanouts) self.n_layer = len(self.graph_sampler.fanouts)
def collate(self,items): def collate(self,items):
''' '''
...@@ -713,7 +714,7 @@ class SimpleTemporalEdgeCollator(dgl.dataloading.EdgeCollator): ...@@ -713,7 +714,7 @@ class SimpleTemporalEdgeCollator(dgl.dataloading.EdgeCollator):
We sample iteratively k-times and batch them into one single subgraph. We sample iteratively k-times and batch them into one single subgraph.
''' '''
current_ts = self.g.edata['timestamp'][items[0]] #only sample edges before current timestamp current_ts = self.g.edata['timestamp'][items[0]] #only sample edges before current timestamp
self.block_sampler.ts = current_ts # restore the current timestamp to the graph sampler. self.graph_sampler.ts = current_ts # restore the current timestamp to the graph sampler.
# if link prefiction, we use a negative_sampler to generate neg-graph for loss computing. # if link prefiction, we use a negative_sampler to generate neg-graph for loss computing.
if self.negative_sampler is None: if self.negative_sampler is None:
...@@ -724,8 +725,8 @@ class SimpleTemporalEdgeCollator(dgl.dataloading.EdgeCollator): ...@@ -724,8 +725,8 @@ class SimpleTemporalEdgeCollator(dgl.dataloading.EdgeCollator):
# we sampling k-hop subgraph and batch them into one graph # we sampling k-hop subgraph and batch them into one graph
for i in range(self.n_layer-1): for i in range(self.n_layer-1):
self.block_sampler.frontiers[0].add_edges(*self.block_sampler.frontiers[i+1].edges()) self.graph_sampler.frontiers[0].add_edges(*self.graph_sampler.frontiers[i+1].edges())
frontier = self.block_sampler.frontiers[0] frontier = self.graph_sampler.frontiers[0]
# computing node last-update timestamp # computing node last-update timestamp
frontier.update_all(fn.copy_e('timestamp','ts'), fn.max('ts','timestamp')) frontier.update_all(fn.copy_e('timestamp','ts'), fn.max('ts','timestamp'))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment