Unverified Commit 1ad78fba authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Update minibatch `to` to support seeds. (#7235)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent a3b09f74
...@@ -221,6 +221,10 @@ class CopyTo(IterDataPipe): ...@@ -221,6 +221,10 @@ class CopyTo(IterDataPipe):
``sampled_subgraphs``, ``node_features`` and ``edge_features`` will be ``sampled_subgraphs``, ``node_features`` and ``edge_features`` will be
transferred. transferred.
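- When ``seeds`` and ``compacted_seeds`` are both not None, only ``labels``,
``compacted_seeds``, ``sampled_subgraphs``, ``node_features`` and
``edge_features`` will be transferred. In addition, ``indexes`` is
transferred when it is not None, and ``seeds`` itself is transferred when
``labels`` is None (layerwise inference).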
- Otherwise, all attributes will be transferred. - Otherwise, all attributes will be transferred.
- If you want some other attributes to be transferred as well, please - If you want some other attributes to be transferred as well, please
......
...@@ -583,6 +583,20 @@ class MiniBatch: ...@@ -583,6 +583,20 @@ class MiniBatch:
"node_features", "node_features",
"edge_features", "edge_features",
] ]
elif self.seeds is not None and self.compacted_seeds is not None:
# Node/link/edge related tasks.
transfer_attrs = [
"labels",
"compacted_seeds",
"sampled_subgraphs",
"node_features",
"edge_features",
]
if self.indexes is not None:
transfer_attrs.append("indexes")
if self.labels is None:
# Layerwise inference
transfer_attrs.append("seeds")
else: else:
# Otherwise copy all the attributes to the device. # Otherwise copy all the attributes to the device.
transfer_attrs = get_attributes(self) transfer_attrs = get_attributes(self)
......
...@@ -41,7 +41,7 @@ def test_CopyTo(): ...@@ -41,7 +41,7 @@ def test_CopyTo():
], ],
) )
@unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test") @unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
def test_CopyToWithMiniBatches(task): def test_CopyToWithMiniBatches_original(task):
N = 16 N = 16
B = 2 B = 2
if task == "node_classification" or task == "extra_attrs": if task == "node_classification" or task == "extra_attrs":
...@@ -173,6 +173,130 @@ def test_CopyToWithMiniBatches(task): ...@@ -173,6 +173,130 @@ def test_CopyToWithMiniBatches(task):
test_data_device(datapipe.copy_to("cuda", extra_attrs)) test_data_device(datapipe.copy_to("cuda", extra_attrs))
@pytest.mark.parametrize(
    "task",
    [
        "node_classification",
        "node_inference",
        "link_prediction",
        "edge_classification",
        "extra_attrs",
    ],
)
@unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
def test_CopyToWithMiniBatches(task):
    """Check which minibatch attributes ``CopyTo`` moves to the GPU.

    For every parametrized ``task`` a small sampling pipeline is built on
    top of a ``seeds``-based item set. After ``CopyTo`` runs, each attribute
    listed for that task must live on ``cuda`` and every other
    device-carrying attribute must still live on ``cpu``.

    Parameters
    ----------
    task : str
        One of ``"node_classification"``, ``"node_inference"``,
        ``"link_prediction"``, ``"edge_classification"`` or
        ``"extra_attrs"``.
    """
    N = 16
    B = 2

    # Build the item set. Node tasks use 1-D seeds, link/edge tasks use
    # (N, 2) seeds; inference is the only task without labels.
    if task in ("node_classification", "extra_attrs"):
        itemset = gb.ItemSet(
            (torch.arange(N), torch.arange(N)), names=("seeds", "labels")
        )
    elif task == "node_inference":
        itemset = gb.ItemSet(torch.arange(N), names="seeds")
    elif task in ("link_prediction", "edge_classification"):
        itemset = gb.ItemSet(
            (torch.arange(2 * N).reshape(-1, 2), torch.arange(N)),
            names=("seeds", "labels"),
        )
    else:
        raise ValueError(f"Unexpected task: {task}")

    graph = gb_test_utils.rand_csc_graph(100, 0.15, bidirection_edge=True)

    keys = [("node", None, "a"), ("node", None, "b")]
    features = {
        keys[0]: gb.TorchBasedFeature(torch.randn(200, 4)),
        keys[1]: gb.TorchBasedFeature(torch.randn(200, 4)),
    }
    feature_store = gb.BasicFeatureStore(features)

    datapipe = gb.ItemSampler(itemset, batch_size=B)
    datapipe = gb.NeighborSampler(
        datapipe,
        graph,
        fanouts=[torch.LongTensor([2]) for _ in range(2)],
    )
    # The inference pipeline deliberately runs without a feature fetcher.
    if task != "node_inference":
        datapipe = gb.FeatureFetcher(
            datapipe,
            feature_store,
            ["a"],
        )

    # Attributes expected on the GPU after the copy, per task.
    common_attrs = [
        "node_features",
        "edge_features",
        "compacted_seeds",
        "sampled_subgraphs",
        "labels",
        "blocks",
    ]
    expected_on_cuda = {
        "node_classification": common_attrs,
        "node_inference": [
            "seeds",
            "compacted_seeds",
            "sampled_subgraphs",
            "blocks",
            "labels",
        ],
        "link_prediction": common_attrs + ["indexes"],
        "edge_classification": common_attrs + ["indexes"],
        "extra_attrs": common_attrs + ["seed_nodes"],
    }
    copied_attrs = frozenset(expected_on_cuda[task])

    def check_data_device(pipe):
        """Assert the device of every device-carrying minibatch attribute."""
        for data in pipe:
            for attr in dir(data):
                # Dunder attributes are never part of the device contract.
                if attr.startswith("__"):
                    continue
                var = getattr(data, attr)
                # Containers are represented by their first element. An
                # empty container yields None and is skipped below instead
                # of raising StopIteration.
                if isinstance(var, Mapping):
                    var = next(iter(var.values()), None)
                elif isinstance(var, Iterable):
                    var = next(iter(var), None)
                if (
                    var is None
                    or callable(var)
                    or not hasattr(var, "device")
                ):
                    continue
                expected = "cuda" if attr in copied_attrs else "cpu"
                assert var.device.type == expected, attr

    extra_attrs = ["seed_nodes"] if task == "extra_attrs" else None

    # Invoke CopyTo via class constructor.
    check_data_device(gb.CopyTo(datapipe, "cuda", extra_attrs))

    # Invoke CopyTo via functional form.
    check_data_device(datapipe.copy_to("cuda", extra_attrs))
def test_etype_tuple_to_str(): def test_etype_tuple_to_str():
"""Convert etype from tuple to string.""" """Convert etype from tuple to string."""
# Test for expected input. # Test for expected input.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment