Unverified Commit 333ce36c authored by Ramon Zhou's avatar Ramon Zhou Committed by GitHub
Browse files

[GraphBolt] Add test for CopyTo extra_attrs (#6906)

parent 3b37918b
......@@ -33,6 +33,7 @@ def test_CopyTo():
"node_inference",
"link_prediction",
"edge_classification",
"extra_attrs",
"other",
],
)
......@@ -40,7 +41,7 @@ def test_CopyTo():
def test_CopyToWithMiniBatches(task):
N = 16
B = 2
if task == "node_classification":
if task == "node_classification" or task == "extra_attrs":
itemset = gb.ItemSet(
(torch.arange(N), torch.arange(N)), names=("seed_nodes", "labels")
)
......@@ -125,6 +126,15 @@ def test_CopyToWithMiniBatches(task):
"negative_node_pairs",
"node_pairs_with_labels",
]
elif task == "extra_attrs":
copied_attrs = [
"node_features",
"edge_features",
"sampled_subgraphs",
"labels",
"blocks",
"seed_nodes",
]
def test_data_device(datapipe):
for data in datapipe:
......@@ -148,11 +158,16 @@ def test_CopyToWithMiniBatches(task):
else:
assert var.device.type == "cpu"
if task == "extra_attrs":
extra_attrs = ["seed_nodes"]
else:
extra_attrs = None
# Invoke CopyTo via class constructor.
test_data_device(gb.CopyTo(datapipe, "cuda"))
test_data_device(gb.CopyTo(datapipe, "cuda", extra_attrs))
# Invoke CopyTo via functional form.
test_data_device(datapipe.copy_to("cuda"))
test_data_device(datapipe.copy_to("cuda", extra_attrs))
def test_etype_tuple_to_str():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment