Unverified Commit fe3d29ac authored by kkranen's avatar kkranen Committed by GitHub
Browse files

[Dataloading] Ignore edge types without reverse types in edge dataloader (#5411)



* Bugfix for reverse edge issue.

* Resolved copy/paste transcription error

* lint

* Added unit tests for graph with not all reverses.

* linting + added backend device test

---------
Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
Co-authored-by: default avatarRhett Ying <85214957+Rhett-Ying@users.noreply.github.com>
parent ae4a5b73
...@@ -280,7 +280,11 @@ def _find_exclude_eids_with_reverse_types(g, eids, reverse_etype_map): ...@@ -280,7 +280,11 @@ def _find_exclude_eids_with_reverse_types(g, eids, reverse_etype_map):
for k, v in reverse_etype_map.items() for k, v in reverse_etype_map.items()
} }
exclude_eids.update( exclude_eids.update(
{reverse_etype_map[k]: v for k, v in exclude_eids.items()} {
reverse_etype_map[k]: v
for k, v in exclude_eids.items()
if k in reverse_etype_map
}
) )
return exclude_eids return exclude_eids
......
...@@ -622,6 +622,40 @@ def test_edge_dataloader_excludes( ...@@ -622,6 +622,40 @@ def test_edge_dataloader_excludes(
break break
def test_edge_dataloader_exclusion_without_all_reverses():
data_dict = {
("A", "AB", "B"): (torch.tensor([0, 1]), torch.tensor([0, 1])),
("B", "BA", "A"): (torch.tensor([0, 1]), torch.tensor([0, 1])),
("B", "BC", "C"): (torch.tensor([0]), torch.tensor([0])),
("C", "CA", "A"): (torch.tensor([0, 1]), torch.tensor([0, 1])),
}
g = dgl.heterograph(data_dict=data_dict)
block_sampler = dgl.dataloading.MultiLayerNeighborSampler(
fanouts=[1], replace=True
)
block_sampler = dgl.dataloading.as_edge_prediction_sampler(
block_sampler,
exclude="reverse_types",
reverse_etypes={"AB": "BA"},
)
d = dgl.dataloading.DataLoader(
graph=g,
indices={
"AB": torch.tensor([0]),
"BC": torch.tensor([0]),
},
graph_sampler=block_sampler,
batch_size=2,
shuffle=True,
drop_last=False,
num_workers=0,
device=F.ctx(),
use_ddp=False,
)
next(iter(d))
def dummy_worker_init_fn(worker_id): def dummy_worker_init_fn(worker_id):
pass pass
...@@ -647,3 +681,4 @@ if __name__ == "__main__": ...@@ -647,3 +681,4 @@ if __name__ == "__main__":
test_edge_dataloader_excludes( test_edge_dataloader_excludes(
"reverse_types", False, 1, dgl.dataloading.ShaDowKHopSampler([5]) "reverse_types", False, 1, dgl.dataloading.ShaDowKHopSampler([5])
) )
test_edge_dataloader_exclusion_without_all_reverses()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment