Unverified Commit 0cb309a1 authored by Ramon Zhou's avatar Ramon Zhou Committed by GitHub
Browse files

[GraphBolt] Refine `CopyTo` (#6791)


Co-authored-by: Rhett Ying <85214957+Rhett-Ying@users.noreply.github.com>
parent d780f6d0
......@@ -102,7 +102,9 @@ def apply_to(x, device):
@functional_datapipe("copy_to")
class CopyTo(IterDataPipe):
    """DataPipe that transfers each element yielded from the previous DataPipe
    to the given device. For MiniBatch, only the related attributes
    (automatically inferred) will be transferred by default. If you want to
    transfer any other attributes, indicate them in the `extra_attrs`.

    Functional name: :obj:`copy_to`.

    Parameters
    ----------
    datapipe : DataPipe
        The DataPipe.
    device : torch.device
        The PyTorch CUDA device.
    extra_attrs : List[str]
        The extra attributes in the MiniBatch you want to be carried to the
        specific device.
    """

    def __init__(self, datapipe, device, extra_attrs=None):
        super().__init__()
        self.datapipe = datapipe
        self.device = device
        self.extra_attrs = extra_attrs

    def __iter__(self):
        for data in self.datapipe:
            # Transfer the element itself; for MiniBatch this defers to its
            # `to` method, which only moves the task-relevant attributes.
            data = recursive_apply(data, apply_to, self.device)
            if self.extra_attrs is not None:
                # Explicitly move any attributes the caller asked for that
                # the default transfer did not cover.
                for attr in self.extra_attrs:
                    setattr(
                        data,
                        attr,
                        recursive_apply(
                            getattr(data, attr), apply_to, self.device
                        ),
                    )
            yield data
......
......@@ -433,25 +433,50 @@ class MiniBatch:
else:
return None
def to(self, device: torch.device):  # pylint: disable=invalid-name
    """Copy `MiniBatch` to the specified device using reflection.

    Only the attributes related to the inferred task type (node-related,
    link/edge-related, or unknown) are transferred; read-only attributes
    are skipped.

    Parameters
    ----------
    device : torch.device
        The target device.

    Returns
    -------
    MiniBatch
        Self, with the selected attributes moved to `device`.
    """

    def _to(x, device):
        # Move a single object if it supports `.to`; otherwise leave it
        # unchanged (e.g. plain ints, strings, None).
        return x.to(device) if hasattr(x, "to") else x

    def apply_to(x, device):
        # Recursively apply `_to` through nested containers.
        return recursive_apply(x, lambda x: _to(x, device))

    if self.seed_nodes is not None and self.compacted_node_pairs is None:
        # Node related tasks.
        transfer_attrs = [
            "labels",
            "sampled_subgraphs",
            "node_features",
            "edge_features",
        ]
        if self.labels is None:
            # Layerwise inference
            transfer_attrs.append("seed_nodes")
    elif self.seed_nodes is None and self.compacted_node_pairs is not None:
        # Link/edge related tasks.
        transfer_attrs = [
            "labels",
            "compacted_node_pairs",
            "compacted_negative_srcs",
            "compacted_negative_dsts",
            "sampled_subgraphs",
            "node_features",
            "edge_features",
        ]
    else:
        # Otherwise copy all the attributes to the device.
        transfer_attrs = get_attributes(self)
    for attr in transfer_attrs:
        try:
            # For read-only attributes such as blocks and
            # node_pairs_with_labels, setattr will throw an AttributeError.
            # We catch these exceptions and skip those attributes.
            setattr(self, attr, apply_to(getattr(self, attr), device))
        except AttributeError:
            continue
    return self
......
import re
import unittest
from collections.abc import Iterable, Mapping
import backend as F
......@@ -25,11 +26,44 @@ def test_CopyTo():
assert data.device.type == "cuda"
@pytest.mark.parametrize(
"task",
[
"node_classification",
"node_inference",
"link_prediction",
"edge_classification",
"other",
],
)
@unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
def test_CopyToWithMiniBatches():
def test_CopyToWithMiniBatches(task):
N = 16
B = 2
itemset = gb.ItemSet(torch.arange(N), names="seed_nodes")
if task == "node_classification":
itemset = gb.ItemSet(
(torch.arange(N), torch.arange(N)), names=("seed_nodes", "labels")
)
elif task == "node_inference":
itemset = gb.ItemSet(torch.arange(N), names="seed_nodes")
elif task == "link_prediction":
itemset = gb.ItemSet(
(
torch.arange(2 * N).reshape(-1, 2),
torch.arange(3 * N).reshape(-1, 3),
),
names=("node_pairs", "negative_dsts"),
)
elif task == "edge_classification":
itemset = gb.ItemSet(
(torch.arange(2 * N).reshape(-1, 2), torch.arange(N)),
names=("node_pairs", "labels"),
)
else:
itemset = gb.ItemSet(
(torch.arange(2 * N).reshape(-1, 2), torch.arange(N)),
names=("node_pairs", "seed_nodes"),
)
graph = gb_test_utils.rand_csc_graph(100, 0.15, bidirection_edge=True)
features = {}
......@@ -44,28 +78,75 @@ def test_CopyToWithMiniBatches():
graph,
fanouts=[torch.LongTensor([2]) for _ in range(2)],
)
datapipe = gb.FeatureFetcher(
datapipe,
feature_store,
["a"],
)
if task != "node_inference":
datapipe = gb.FeatureFetcher(
datapipe,
feature_store,
["a"],
)
if task == "node_classification":
copied_attrs = [
"node_features",
"edge_features",
"sampled_subgraphs",
"labels",
"blocks",
]
elif task == "node_inference":
copied_attrs = [
"seed_nodes",
"sampled_subgraphs",
"blocks",
"labels",
]
elif task == "link_prediction":
copied_attrs = [
"compacted_node_pairs",
"node_features",
"edge_features",
"sampled_subgraphs",
"compacted_negative_srcs",
"compacted_negative_dsts",
"blocks",
"positive_node_pairs",
"negative_node_pairs",
"node_pairs_with_labels",
]
elif task == "edge_classification":
copied_attrs = [
"compacted_node_pairs",
"node_features",
"edge_features",
"sampled_subgraphs",
"labels",
"blocks",
"positive_node_pairs",
"negative_node_pairs",
"node_pairs_with_labels",
]
def test_data_device(datapipe):
for data in datapipe:
for attr in dir(data):
var = getattr(data, attr)
if isinstance(var, Mapping):
var = var[next(iter(var))]
elif isinstance(var, Iterable):
var = next(iter(var))
if (
not callable(var)
and not attr.startswith("__")
and hasattr(var, "device")
and var is not None
):
assert var.device.type == "cuda"
# Invoke CopyTo via class constructor.
test_data_device(gb.CopyTo(datapipe, "cuda"))
# Invoke CopyTo via functional form.
test_data_device(datapipe.copy_to("cuda"))
if task == "other":
assert var.device.type == "cuda"
else:
if attr in copied_attrs:
assert var.device.type == "cuda"
else:
assert var.device.type == "cpu"
# Invoke CopyTo via class constructor.
test_data_device(gb.CopyTo(datapipe, "cuda"))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment