Unverified Commit 71157b05 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bug] Fix problem with ShaDowKHopSampler working with reverse edge type exclusion (#4145)

* fix

* fix

* Update utils.py
parent 794ec4a4
......@@ -12,8 +12,12 @@ def _locate_eids_to_exclude(frontier_parent_eids, exclude_eids):
Note that both arguments are numpy arrays or numpy dicts.
"""
func = lambda x, y: np.isin(x, y).nonzero()[0]
result = recursive_apply_pair(frontier_parent_eids, exclude_eids, func)
if not isinstance(frontier_parent_eids, Mapping):
return np.isin(frontier_parent_eids, exclude_eids).nonzero()[0]
result = {}
for k, v in frontier_parent_eids.items():
if k in exclude_eids:
result[k] = np.isin(v, exclude_eids[k]).nonzero()[0]
return recursive_apply(result, F.zerocopy_from_numpy)
class EidExcluder(object):
......
......@@ -345,13 +345,15 @@ def _find_edges_to_exclude(g, exclude, always_exclude, pair_eids):
@pytest.mark.parametrize('always_exclude_flag', [False, True])
@pytest.mark.parametrize('exclude', [None, 'self', 'reverse_id', 'reverse_types'])
def test_edge_dataloader_excludes(exclude, always_exclude_flag):
@pytest.mark.parametrize('sampler', [dgl.dataloading.MultiLayerFullNeighborSampler(1),
dgl.dataloading.ShaDowKHopSampler([5])])
@pytest.mark.parametrize('batch_size', [1, 50])
def test_edge_dataloader_excludes(exclude, always_exclude_flag, batch_size, sampler):
if exclude == 'reverse_types':
g, reverse_etypes, always_exclude, seed_edges = _create_heterogeneous()
else:
g, reverse_eids, always_exclude, seed_edges = _create_homogeneous()
g = g.to(F.ctx())
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
if not always_exclude_flag:
always_exclude = None
......@@ -361,13 +363,17 @@ def test_edge_dataloader_excludes(exclude, always_exclude_flag):
else exclude)
kwargs['reverse_eids'] = reverse_eids if exclude == 'reverse_id' else None
kwargs['reverse_etypes'] = reverse_etypes if exclude == 'reverse_types' else None
sampler = dgl.dataloading.as_edge_prediction_sampler(sampler, **kwargs)
dataloader = dgl.dataloading.EdgeDataLoader(
g, seed_edges, sampler, batch_size=50, device=F.ctx(), **kwargs)
for input_nodes, pair_graph, blocks in dataloader:
block = blocks[0]
dataloader = dgl.dataloading.DataLoader(
g, seed_edges, sampler, batch_size=batch_size, device=F.ctx(), use_prefetch_thread=False)
for i, (input_nodes, pair_graph, blocks) in enumerate(dataloader):
if isinstance(blocks, list):
subg = blocks[0]
else:
subg = blocks
pair_eids = pair_graph.edata[dgl.EID]
block_eids = block.edata[dgl.EID]
block_eids = subg.edata[dgl.EID]
edges_to_exclude = _find_edges_to_exclude(g, exclude, always_exclude, pair_eids)
if edges_to_exclude is None:
......@@ -381,5 +387,9 @@ def test_edge_dataloader_excludes(exclude, always_exclude_flag):
else:
assert not np.isin(edges_to_exclude, block_eids).any()
if i == 10:
break
if __name__ == '__main__':
test_node_dataloader(F.int32, 'neighbor', None)
#test_node_dataloader(F.int32, 'neighbor', None)
test_edge_dataloader_excludes('reverse_types', False, 1, dgl.dataloading.ShaDowKHopSampler([5]))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment