Unverified Commit fde1896f authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Fix `to` when using `seeds`. (#7245)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 3885da25
...@@ -591,15 +591,10 @@ class MiniBatch: ...@@ -591,15 +591,10 @@ class MiniBatch:
"sampled_subgraphs", "sampled_subgraphs",
"node_features", "node_features",
"edge_features", "edge_features",
"compacted_seeds",
"indexes",
"seeds",
] ]
# Link/edge related tasks.
if self.compacted_seeds is not None:
transfer_attrs.append("compacted_seeds")
if self.indexes is not None:
transfer_attrs.append("indexes")
if self.labels is None:
# Layerwise inference
transfer_attrs.append("seeds")
else: else:
# Otherwise copy all the attributes to the device. # Otherwise copy all the attributes to the device.
transfer_attrs = get_attributes(self) transfer_attrs = get_attributes(self)
......
...@@ -234,6 +234,7 @@ def test_CopyToWithMiniBatches(task): ...@@ -234,6 +234,7 @@ def test_CopyToWithMiniBatches(task):
"sampled_subgraphs", "sampled_subgraphs",
"labels", "labels",
"blocks", "blocks",
"seeds",
] ]
elif task == "node_inference": elif task == "node_inference":
copied_attrs = [ copied_attrs = [
...@@ -251,6 +252,7 @@ def test_CopyToWithMiniBatches(task): ...@@ -251,6 +252,7 @@ def test_CopyToWithMiniBatches(task):
"node_features", "node_features",
"edge_features", "edge_features",
"blocks", "blocks",
"seeds",
] ]
elif task == "extra_attrs": elif task == "extra_attrs":
copied_attrs = [ copied_attrs = [
...@@ -260,6 +262,7 @@ def test_CopyToWithMiniBatches(task): ...@@ -260,6 +262,7 @@ def test_CopyToWithMiniBatches(task):
"labels", "labels",
"blocks", "blocks",
"seed_nodes", "seed_nodes",
"seeds",
] ]
def test_data_device(datapipe): def test_data_device(datapipe):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment