Unverified Commit 71157b05 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bug] Fix problem with ShaDowKHopSampler working with reverse edge type exclusion (#4145)

* fix

* fix

* Update utils.py
parent 794ec4a4
...@@ -12,8 +12,12 @@ def _locate_eids_to_exclude(frontier_parent_eids, exclude_eids): ...@@ -12,8 +12,12 @@ def _locate_eids_to_exclude(frontier_parent_eids, exclude_eids):
Note that both arguments are numpy arrays or numpy dicts. Note that both arguments are numpy arrays or numpy dicts.
""" """
func = lambda x, y: np.isin(x, y).nonzero()[0] if not isinstance(frontier_parent_eids, Mapping):
result = recursive_apply_pair(frontier_parent_eids, exclude_eids, func) return np.isin(frontier_parent_eids, exclude_eids).nonzero()[0]
result = {}
for k, v in frontier_parent_eids.items():
if k in exclude_eids:
result[k] = np.isin(v, exclude_eids[k]).nonzero()[0]
return recursive_apply(result, F.zerocopy_from_numpy) return recursive_apply(result, F.zerocopy_from_numpy)
class EidExcluder(object): class EidExcluder(object):
......
...@@ -345,13 +345,15 @@ def _find_edges_to_exclude(g, exclude, always_exclude, pair_eids): ...@@ -345,13 +345,15 @@ def _find_edges_to_exclude(g, exclude, always_exclude, pair_eids):
@pytest.mark.parametrize('always_exclude_flag', [False, True]) @pytest.mark.parametrize('always_exclude_flag', [False, True])
@pytest.mark.parametrize('exclude', [None, 'self', 'reverse_id', 'reverse_types']) @pytest.mark.parametrize('exclude', [None, 'self', 'reverse_id', 'reverse_types'])
def test_edge_dataloader_excludes(exclude, always_exclude_flag): @pytest.mark.parametrize('sampler', [dgl.dataloading.MultiLayerFullNeighborSampler(1),
dgl.dataloading.ShaDowKHopSampler([5])])
@pytest.mark.parametrize('batch_size', [1, 50])
def test_edge_dataloader_excludes(exclude, always_exclude_flag, batch_size, sampler):
if exclude == 'reverse_types': if exclude == 'reverse_types':
g, reverse_etypes, always_exclude, seed_edges = _create_heterogeneous() g, reverse_etypes, always_exclude, seed_edges = _create_heterogeneous()
else: else:
g, reverse_eids, always_exclude, seed_edges = _create_homogeneous() g, reverse_eids, always_exclude, seed_edges = _create_homogeneous()
g = g.to(F.ctx()) g = g.to(F.ctx())
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
if not always_exclude_flag: if not always_exclude_flag:
always_exclude = None always_exclude = None
...@@ -361,13 +363,17 @@ def test_edge_dataloader_excludes(exclude, always_exclude_flag): ...@@ -361,13 +363,17 @@ def test_edge_dataloader_excludes(exclude, always_exclude_flag):
else exclude) else exclude)
kwargs['reverse_eids'] = reverse_eids if exclude == 'reverse_id' else None kwargs['reverse_eids'] = reverse_eids if exclude == 'reverse_id' else None
kwargs['reverse_etypes'] = reverse_etypes if exclude == 'reverse_types' else None kwargs['reverse_etypes'] = reverse_etypes if exclude == 'reverse_types' else None
sampler = dgl.dataloading.as_edge_prediction_sampler(sampler, **kwargs)
dataloader = dgl.dataloading.EdgeDataLoader( dataloader = dgl.dataloading.DataLoader(
g, seed_edges, sampler, batch_size=50, device=F.ctx(), **kwargs) g, seed_edges, sampler, batch_size=batch_size, device=F.ctx(), use_prefetch_thread=False)
for input_nodes, pair_graph, blocks in dataloader: for i, (input_nodes, pair_graph, blocks) in enumerate(dataloader):
block = blocks[0] if isinstance(blocks, list):
subg = blocks[0]
else:
subg = blocks
pair_eids = pair_graph.edata[dgl.EID] pair_eids = pair_graph.edata[dgl.EID]
block_eids = block.edata[dgl.EID] block_eids = subg.edata[dgl.EID]
edges_to_exclude = _find_edges_to_exclude(g, exclude, always_exclude, pair_eids) edges_to_exclude = _find_edges_to_exclude(g, exclude, always_exclude, pair_eids)
if edges_to_exclude is None: if edges_to_exclude is None:
...@@ -381,5 +387,9 @@ def test_edge_dataloader_excludes(exclude, always_exclude_flag): ...@@ -381,5 +387,9 @@ def test_edge_dataloader_excludes(exclude, always_exclude_flag):
else: else:
assert not np.isin(edges_to_exclude, block_eids).any() assert not np.isin(edges_to_exclude, block_eids).any()
if i == 10:
break
if __name__ == '__main__': if __name__ == '__main__':
test_node_dataloader(F.int32, 'neighbor', None) #test_node_dataloader(F.int32, 'neighbor', None)
test_edge_dataloader_excludes('reverse_types', False, 1, dgl.dataloading.ShaDowKHopSampler([5]))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment