Unverified Commit 4864a9f9 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Bug] Merge rather than replace when finding reverse eids (#5532)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-16-19.ap-northeast-1.compute.internal>
parent 62b5f50a
...@@ -279,13 +279,14 @@ def _find_exclude_eids_with_reverse_types(g, eids, reverse_etype_map): ...@@ -279,13 +279,14 @@ def _find_exclude_eids_with_reverse_types(g, eids, reverse_etype_map):
g.to_canonical_etype(k): g.to_canonical_etype(v) g.to_canonical_etype(k): g.to_canonical_etype(v)
for k, v in reverse_etype_map.items() for k, v in reverse_etype_map.items()
} }
exclude_eids.update( for k, v in reverse_etype_map.items():
{ if k in exclude_eids:
reverse_etype_map[k]: v if v in exclude_eids:
for k, v in exclude_eids.items() exclude_eids[v] = F.unique(
if k in reverse_etype_map F.cat((exclude_eids[k], exclude_eids[v]), dim=0)
} )
) else:
exclude_eids[v] = exclude_eids[k]
return exclude_eids return exclude_eids
......
...@@ -515,6 +515,11 @@ def _create_heterogeneous(): ...@@ -515,6 +515,11 @@ def _create_heterogeneous():
return g, reverse_etypes, always_exclude, seed_edges return g, reverse_etypes, always_exclude, seed_edges
def _remove_duplicates(s, d):
    """Drop repeated (source, destination) pairs from two parallel tensors.

    Parameters
    ----------
    s, d : tensor
        Parallel tensors; element ``i`` of each forms one pair. Both are
        converted with ``.tolist()``, and pairing stops at the shorter one.

    Returns
    -------
    tuple of tensor
        Two new tensors on ``F.ctx()`` holding the distinct pairs. Pairs are
        collected through a ``set``, so their order is unspecified.
    """
    unique_pairs = set(zip(s.tolist(), d.tolist()))
    if not unique_pairs:
        # zip(*[]) yields nothing, so the two-way unpacking below would raise
        # ValueError. int64 is what torch.tensor infers for Python ints, i.e.
        # the dtype the non-empty path produces for integer inputs.
        return (
            torch.tensor([], dtype=torch.int64, device=F.ctx()),
            torch.tensor([], dtype=torch.int64, device=F.ctx()),
        )
    src, dst = zip(*unique_pairs)
    return torch.tensor(src, device=F.ctx()), torch.tensor(dst, device=F.ctx())
def _find_edges_to_exclude(g, exclude, always_exclude, pair_eids): def _find_edges_to_exclude(g, exclude, always_exclude, pair_eids):
if exclude == None: if exclude == None:
return always_exclude return always_exclude
...@@ -622,6 +627,45 @@ def test_edge_dataloader_excludes( ...@@ -622,6 +627,45 @@ def test_edge_dataloader_excludes(
break break
def test_edge_dataloader_exclusion_with_reverse_seed_nodes():
    """Seed edges must not show up in the last MFG when reverses are seeds too.

    Builds a random two-relation graph where every "AB" edge (s, d) has a
    matching "BA" edge (d, s), uses every edge of both types as a seed with
    ``exclude="reverse_types"``, and checks for each batch that no edge of
    the positive graph appears among the edges of ``mfgs[-1]`` for the same
    edge type.
    """
    src = torch.randint(0, 20, (500,), device=F.ctx())
    dst = torch.randint(0, 20, (500,), device=F.ctx())
    src, dst = _remove_duplicates(src, dst)
    g = dgl.heterograph(
        {("A", "AB", "B"): (src, dst), ("B", "BA", "A"): (dst, src)}
    )
    sampler = dgl.dataloading.as_edge_prediction_sampler(
        dgl.dataloading.NeighborSampler(fanouts=[2, 2, 2]),
        exclude="reverse_types",
        reverse_etypes={"AB": "BA", "BA": "AB"},
    )
    seed_edges = {
        "AB": torch.arange(g.number_of_edges("AB"), device=F.ctx()),
        "BA": torch.arange(g.number_of_edges("BA"), device=F.ctx()),
    }
    dataloader = dgl.dataloading.DataLoader(
        g,
        seed_edges,
        sampler,
        batch_size=2,
        device=F.ctx(),
        shuffle=True,
        drop_last=False,
    )

    def _edge_pairs(graph, etype):
        # Endpoint pairs of one edge type as plain Python tuples.
        u, v = graph[etype].edges()
        return list(zip(u.tolist(), v.tolist()))

    for _, pos_graph, mfgs in dataloader:
        for etype in ("AB", "BA"):
            pos_edges = _edge_pairs(pos_graph, etype)
            # A set gives O(1) membership instead of scanning a list per edge.
            mfg_edges = set(_edge_pairs(mfgs[-1], etype))
            assert mfg_edges.isdisjoint(pos_edges)
def test_edge_dataloader_exclusion_without_all_reverses(): def test_edge_dataloader_exclusion_without_all_reverses():
data_dict = { data_dict = {
("A", "AB", "B"): (torch.tensor([0, 1]), torch.tensor([0, 1])), ("A", "AB", "B"): (torch.tensor([0, 1]), torch.tensor([0, 1])),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment