Unverified Commit fde1896f authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Fix `to` when using `seeds`. (#7245)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 3885da25
......@@ -591,15 +591,10 @@ class MiniBatch:
"sampled_subgraphs",
"node_features",
"edge_features",
"compacted_seeds",
"indexes",
"seeds",
]
# Link/edge related tasks.
if self.compacted_seeds is not None:
transfer_attrs.append("compacted_seeds")
if self.indexes is not None:
transfer_attrs.append("indexes")
if self.labels is None:
# Layerwise inference
transfer_attrs.append("seeds")
else:
# Otherwise copy all the attributes to the device.
transfer_attrs = get_attributes(self)
......
......@@ -234,6 +234,7 @@ def test_CopyToWithMiniBatches(task):
"sampled_subgraphs",
"labels",
"blocks",
"seeds",
]
elif task == "node_inference":
copied_attrs = [
......@@ -251,6 +252,7 @@ def test_CopyToWithMiniBatches(task):
"node_features",
"edge_features",
"blocks",
"seeds",
]
elif task == "extra_attrs":
copied_attrs = [
......@@ -260,6 +262,7 @@ def test_CopyToWithMiniBatches(task):
"labels",
"blocks",
"seed_nodes",
"seeds",
]
def test_data_device(datapipe):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment