Unverified Commit 333ce36c authored by Ramon Zhou's avatar Ramon Zhou Committed by GitHub
Browse files

[GraphBolt] Add test for CopyTo extra_attrs (#6906)

parent 3b37918b
...@@ -33,6 +33,7 @@ def test_CopyTo(): ...@@ -33,6 +33,7 @@ def test_CopyTo():
"node_inference", "node_inference",
"link_prediction", "link_prediction",
"edge_classification", "edge_classification",
"extra_attrs",
"other", "other",
], ],
) )
...@@ -40,7 +41,7 @@ def test_CopyTo(): ...@@ -40,7 +41,7 @@ def test_CopyTo():
def test_CopyToWithMiniBatches(task): def test_CopyToWithMiniBatches(task):
N = 16 N = 16
B = 2 B = 2
if task == "node_classification": if task == "node_classification" or task == "extra_attrs":
itemset = gb.ItemSet( itemset = gb.ItemSet(
(torch.arange(N), torch.arange(N)), names=("seed_nodes", "labels") (torch.arange(N), torch.arange(N)), names=("seed_nodes", "labels")
) )
...@@ -125,6 +126,15 @@ def test_CopyToWithMiniBatches(task): ...@@ -125,6 +126,15 @@ def test_CopyToWithMiniBatches(task):
"negative_node_pairs", "negative_node_pairs",
"node_pairs_with_labels", "node_pairs_with_labels",
] ]
elif task == "extra_attrs":
copied_attrs = [
"node_features",
"edge_features",
"sampled_subgraphs",
"labels",
"blocks",
"seed_nodes",
]
def test_data_device(datapipe): def test_data_device(datapipe):
for data in datapipe: for data in datapipe:
...@@ -148,11 +158,16 @@ def test_CopyToWithMiniBatches(task): ...@@ -148,11 +158,16 @@ def test_CopyToWithMiniBatches(task):
else: else:
assert var.device.type == "cpu" assert var.device.type == "cpu"
if task == "extra_attrs":
extra_attrs = ["seed_nodes"]
else:
extra_attrs = None
# Invoke CopyTo via class constructor. # Invoke CopyTo via class constructor.
test_data_device(gb.CopyTo(datapipe, "cuda")) test_data_device(gb.CopyTo(datapipe, "cuda", extra_attrs))
# Invoke CopyTo via functional form. # Invoke CopyTo via functional form.
test_data_device(datapipe.copy_to("cuda")) test_data_device(datapipe.copy_to("cuda", extra_attrs))
def test_etype_tuple_to_str(): def test_etype_tuple_to_str():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment