Unverified Commit 0cb309a1 authored by Ramon Zhou, committed by GitHub

[GraphBolt] Refine `CopyTo` (#6791)


Co-authored-by: Rhett Ying <85214957+Rhett-Ying@users.noreply.github.com>
parent d780f6d0
@@ -102,7 +102,9 @@ def apply_to(x, device):
@functional_datapipe("copy_to")
class CopyTo(IterDataPipe):
    """DataPipe that transfers each element yielded from the previous DataPipe
    to the given device. For a MiniBatch, only the task-related attributes
    (automatically inferred) are transferred by default. To transfer any
    other attributes, list them in `extra_attrs`.

    Functional name: :obj:`copy_to`.
@@ -119,16 +121,29 @@ class CopyTo(IterDataPipe):
        The DataPipe.
    device : torch.device
        The PyTorch CUDA device.
    extra_attrs : List[str]
        Extra attributes of the MiniBatch to be transferred to the given
        device, beyond the automatically inferred ones.
    """

    def __init__(self, datapipe, device, extra_attrs=None):
        super().__init__()
        self.datapipe = datapipe
        self.device = device
        self.extra_attrs = extra_attrs
    def __iter__(self):
        for data in self.datapipe:
            data = recursive_apply(data, apply_to, self.device)
            if self.extra_attrs is not None:
                for attr in self.extra_attrs:
                    setattr(
                        data,
                        attr,
                        recursive_apply(
                            getattr(data, attr), apply_to, self.device
                        ),
                    )
            yield data
...
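Note: a minimal usage sketch of the refined CopyTo, under these assumptions: the GraphBolt pipeline API as exercised by the tests further down (gb.ItemSampler, gb.NeighborSampler, gb_test_utils.rand_csc_graph), and a node-classification itemset whose seed_nodes would not be transferred by default once labels are present; forcing it across via the new extra_attrs argument is purely illustrative.

    import torch

    import dgl.graphbolt as gb
    import gb_test_utils  # test helper used by the tests below

    # Hypothetical node-classification pipeline over a random graph.
    itemset = gb.ItemSet(
        (torch.arange(16), torch.arange(16)), names=("seed_nodes", "labels")
    )
    graph = gb_test_utils.rand_csc_graph(100, 0.15, bidirection_edge=True)
    datapipe = gb.ItemSampler(itemset, batch_size=2)
    datapipe = gb.NeighborSampler(
        datapipe, graph, fanouts=[torch.LongTensor([2]) for _ in range(2)]
    )
    # Only task-related MiniBatch attributes move by default (see the
    # MiniBatch.to hunk below); extra_attrs forces seed_nodes across too.
    datapipe = datapipe.copy_to("cuda", extra_attrs=["seed_nodes"])
    for minibatch in datapipe:
        assert minibatch.seed_nodes.device.type == "cuda"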
@@ -433,23 +433,48 @@ class MiniBatch:
        else:
            return None

    def to(self, device: torch.device):  # pylint: disable=invalid-name
        """Copy `MiniBatch` to the specified device using reflection."""

        def _to(x, device):
            return x.to(device) if hasattr(x, "to") else x

        def apply_to(x, device):
            return recursive_apply(x, lambda x: _to(x, device))

        if self.seed_nodes is not None and self.compacted_node_pairs is None:
            # Node related tasks.
            transfer_attrs = [
                "labels",
                "sampled_subgraphs",
                "node_features",
                "edge_features",
            ]
            if self.labels is None:
                # Layerwise inference.
                transfer_attrs.append("seed_nodes")
        elif self.seed_nodes is None and self.compacted_node_pairs is not None:
            # Link/edge related tasks.
            transfer_attrs = [
                "labels",
                "compacted_node_pairs",
                "compacted_negative_srcs",
                "compacted_negative_dsts",
                "sampled_subgraphs",
                "node_features",
                "edge_features",
            ]
        else:
            # Otherwise copy all the attributes to the device.
            transfer_attrs = get_attributes(self)
        for attr in transfer_attrs:
            # Only copy member variables.
            try:
                # For read-only attributes such as `blocks` and
                # `node_pairs_with_labels`, setattr will throw an
                # AttributeError. We catch these exceptions and skip those
                # attributes.
                setattr(self, attr, apply_to(getattr(self, attr), device))
            except AttributeError:
                continue
...
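Note: the try/except above exists because some MiniBatch members (e.g. blocks, node_pairs_with_labels) are read-only properties computed from other fields, so setattr on them raises AttributeError. A toy illustration of the pattern, independent of GraphBolt:

    class Toy:
        """Stand-in for MiniBatch: one plain field, one derived property."""

        def __init__(self):
            self.data = [1, 2, 3]

        @property
        def derived(self):  # read-only: no setter defined
            return len(self.data)


    toy = Toy()
    for attr in ["data", "derived"]:
        try:
            # Writing `derived` back raises AttributeError, mirroring
            # MiniBatch.blocks; the loop simply skips such attributes.
            setattr(toy, attr, getattr(toy, attr))
        except AttributeError:
            continue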
import re
import unittest
from collections.abc import Iterable, Mapping

import backend as F

@@ -25,11 +26,44 @@ def test_CopyTo():
        assert data.device.type == "cuda"
@pytest.mark.parametrize(
    "task",
    [
        "node_classification",
        "node_inference",
        "link_prediction",
        "edge_classification",
        "other",
    ],
)
@unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
def test_CopyToWithMiniBatches(task):
    N = 16
    B = 2
    if task == "node_classification":
        itemset = gb.ItemSet(
            (torch.arange(N), torch.arange(N)), names=("seed_nodes", "labels")
        )
    elif task == "node_inference":
        itemset = gb.ItemSet(torch.arange(N), names="seed_nodes")
    elif task == "link_prediction":
        itemset = gb.ItemSet(
            (
                torch.arange(2 * N).reshape(-1, 2),
                torch.arange(3 * N).reshape(-1, 3),
            ),
            names=("node_pairs", "negative_dsts"),
        )
    elif task == "edge_classification":
        itemset = gb.ItemSet(
            (torch.arange(2 * N).reshape(-1, 2), torch.arange(N)),
            names=("node_pairs", "labels"),
        )
    else:
        itemset = gb.ItemSet(
            (torch.arange(2 * N).reshape(-1, 2), torch.arange(N)),
            names=("node_pairs", "seed_nodes"),
        )

    graph = gb_test_utils.rand_csc_graph(100, 0.15, bidirection_edge=True)

    features = {}
@@ -44,28 +78,75 @@ def test_CopyToWithMiniBatches():
        graph,
        fanouts=[torch.LongTensor([2]) for _ in range(2)],
    )
    if task != "node_inference":
        datapipe = gb.FeatureFetcher(
            datapipe,
            feature_store,
            ["a"],
        )
if task == "node_classification":
copied_attrs = [
"node_features",
"edge_features",
"sampled_subgraphs",
"labels",
"blocks",
]
elif task == "node_inference":
copied_attrs = [
"seed_nodes",
"sampled_subgraphs",
"blocks",
"labels",
]
elif task == "link_prediction":
copied_attrs = [
"compacted_node_pairs",
"node_features",
"edge_features",
"sampled_subgraphs",
"compacted_negative_srcs",
"compacted_negative_dsts",
"blocks",
"positive_node_pairs",
"negative_node_pairs",
"node_pairs_with_labels",
]
elif task == "edge_classification":
copied_attrs = [
"compacted_node_pairs",
"node_features",
"edge_features",
"sampled_subgraphs",
"labels",
"blocks",
"positive_node_pairs",
"negative_node_pairs",
"node_pairs_with_labels",
]
    def test_data_device(datapipe):
        for data in datapipe:
            for attr in dir(data):
                var = getattr(data, attr)
                if isinstance(var, Mapping):
                    var = var[next(iter(var))]
                elif isinstance(var, Iterable):
                    var = next(iter(var))
                if (
                    not callable(var)
                    and not attr.startswith("__")
                    and hasattr(var, "device")
                    and var is not None
                ):
                    if task == "other":
                        assert var.device.type == "cuda"
                    else:
                        if attr in copied_attrs:
                            assert var.device.type == "cuda"
                        else:
                            assert var.device.type == "cpu"

    # Invoke CopyTo via class constructor.
    test_data_device(gb.CopyTo(datapipe, "cuda"))
...
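Note on the device check in test_data_device: container attributes such as node_features (a Mapping keyed by feature name) and sampled_subgraphs (an Iterable of per-layer objects) hold their tensors one level down, so the test unwraps one representative element before reading .device. A standalone sketch of that unwrapping with toy data (plain tensors are themselves Iterable, so they get peeked as well, which preserves the device):

    from collections.abc import Iterable, Mapping

    import torch

    samples = {
        "node_features": {"a": torch.ones(4)},  # Mapping: peek one value
        "sampled_subgraphs": [torch.zeros(2)] * 2,  # Iterable: peek one item
        "labels": torch.arange(4),  # a Tensor is Iterable too: first row peeked
    }
    for attr, var in samples.items():
        if isinstance(var, Mapping):
            var = var[next(iter(var))]
        elif isinstance(var, Iterable):
            var = next(iter(var))
        assert var.device.type == "cpu"  # all toy data above lives on CPU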