"docs/git@developer.sourcefind.cn:OpenDAS/bitsandbytes.git" did not exist on "e82f72b3acd37bfa9f32773e8844ac7bafad2b19"
Unverified Commit 5d4f6bca authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Fix] Fix edge ID exclusion not working in EdgeDataLoader (#3412)

parent 8798872f
...@@ -108,6 +108,7 @@ class _EidExcluder(): ...@@ -108,6 +108,7 @@ class _EidExcluder():
frontier, v, etype=k, store_ids=True) frontier, v, etype=k, store_ids=True)
new_eids[k] = F.gather_row(parent_eids[k], frontier.edges[k].data[EID]) new_eids[k] = F.gather_row(parent_eids[k], frontier.edges[k].data[EID])
frontier.edata[EID] = new_eids frontier.edata[EID] = new_eids
return frontier
def _create_eid_excluder(exclude_eids, device): def _create_eid_excluder(exclude_eids, device):
...@@ -388,7 +389,7 @@ class BlockSampler(object): ...@@ -388,7 +389,7 @@ class BlockSampler(object):
if not self.exclude_edges_in_frontier: if not self.exclude_edges_in_frontier:
eid_excluder = _create_eid_excluder(exclude_eids, self.output_device) eid_excluder = _create_eid_excluder(exclude_eids, self.output_device)
if eid_excluder is not None: if eid_excluder is not None:
eid_excluder(frontier) frontier = eid_excluder(frontier)
block = transform.to_block(frontier, seed_nodes_out) block = transform.to_block(frontier, seed_nodes_out)
if self.return_eids: if self.return_eids:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment