Unverified Commit 617979d6 authored by Ramon Zhou's avatar Ramon Zhou Committed by GitHub
Browse files

[GraphBolt] Add to function for DGLMiniBatch and MiniBatch (#6413)

parent 38448dac
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
import dgl import dgl
from dgl.heterograph import DGLBlock from dgl.heterograph import DGLBlock
from dgl.utils import recursive_apply
from .base import etype_str_to_tuple from .base import etype_str_to_tuple
from .sampled_subgraph import SampledSubgraph from .sampled_subgraph import SampledSubgraph
...@@ -95,6 +96,25 @@ class DGLMiniBatch: ...@@ -95,6 +96,25 @@ class DGLMiniBatch:
given type. given type.
""" """
def to(self, device: torch.device) -> "DGLMiniBatch":  # pylint: disable=invalid-name
    """Copy `DGLMiniBatch` to the specified device using reflection.

    Parameters
    ----------
    device : torch.device
        The target device every tensor-like member is moved to.

    Returns
    -------
    DGLMiniBatch
        ``self``, mutated in place with each member moved to ``device``.
        (The original annotation said ``None``, but ``self`` is returned
        so calls can be chained.)
    """

    def _to(x, device):
        # Move ``x`` if it supports ``.to``; leave any other value as-is.
        return x.to(device) if hasattr(x, "to") else x

    for attr in dir(self):
        value = getattr(self, attr)
        # Only copy member variables, skipping methods and dunders.
        if not callable(value) and not attr.startswith("__"):
            setattr(
                self,
                attr,
                recursive_apply(value, lambda x: _to(x, device)),
            )
    return self
@dataclass @dataclass
class MiniBatch: class MiniBatch:
...@@ -374,6 +394,25 @@ class MiniBatch: ...@@ -374,6 +394,25 @@ class MiniBatch:
} }
return minibatch return minibatch
def to(self, device: torch.device) -> "MiniBatch":  # pylint: disable=invalid-name
    """Copy `MiniBatch` to the specified device using reflection.

    Parameters
    ----------
    device : torch.device
        The target device every tensor-like member is moved to.

    Returns
    -------
    MiniBatch
        ``self``, mutated in place with each member moved to ``device``.
        (The original annotation said ``None``, but ``self`` is returned
        so calls can be chained.)
    """

    def _to(x, device):
        # Move ``x`` if it supports ``.to``; leave any other value as-is.
        return x.to(device) if hasattr(x, "to") else x

    for attr in dir(self):
        value = getattr(self, attr)
        # Only copy member variables, skipping methods and dunders.
        if not callable(value) and not attr.startswith("__"):
            setattr(
                self,
                attr,
                recursive_apply(value, lambda x: _to(x, device)),
            )
    return self
def _minibatch_str(minibatch: MiniBatch) -> str: def _minibatch_str(minibatch: MiniBatch) -> str:
final_str = "" final_str = ""
......
...@@ -4,6 +4,7 @@ import unittest ...@@ -4,6 +4,7 @@ import unittest
import backend as F import backend as F
import dgl.graphbolt as gb import dgl.graphbolt as gb
import gb_test_utils
import pytest import pytest
import torch import torch
...@@ -23,6 +24,58 @@ def test_CopyTo(): ...@@ -23,6 +24,58 @@ def test_CopyTo():
assert data.device.type == "cuda" assert data.device.type == "cuda"
@unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
def test_CopyToWithMiniBatches():
    """Verify CopyTo moves every member of MiniBatch/DGLMiniBatch to GPU."""
    num_seeds = 16
    batch_size = 2
    item_set = gb.ItemSet(torch.arange(num_seeds), names="seed_nodes")
    graph = gb_test_utils.rand_csc_graph(100, 0.15)

    # Two node features are registered; only "a" is fetched below.
    features = {
        ("node", None, "a"): gb.TorchBasedFeature(torch.randn(200, 4)),
        ("node", None, "b"): gb.TorchBasedFeature(torch.randn(200, 4)),
    }
    feature_store = gb.BasicFeatureStore(features)

    pipe = gb.ItemSampler(item_set, batch_size=batch_size)
    pipe = gb.NeighborSampler(
        pipe,
        graph,
        fanouts=[torch.LongTensor([2]) for _ in range(2)],
    )
    pipe = gb.FeatureFetcher(pipe, feature_store, ["a"])

    def assert_on_cuda(datapipe):
        # Every non-callable, non-dunder member exposing ``device``
        # must already live on the GPU after CopyTo.
        for minibatch in datapipe:
            for name in dir(minibatch):
                member = getattr(minibatch, name)
                if callable(member) or name.startswith("__"):
                    continue
                if hasattr(member, "device"):
                    assert member.device.type == "cuda"

    # Invoke CopyTo via class constructor.
    assert_on_cuda(gb.CopyTo(pipe, "cuda"))
    # Invoke CopyTo via functional form.
    assert_on_cuda(pipe.copy_to("cuda"))

    # Repeat both invocations for DGLMiniBatch.
    pipe = gb.DGLMiniBatchConverter(pipe)
    assert_on_cuda(gb.CopyTo(pipe, "cuda"))
    assert_on_cuda(pipe.copy_to("cuda"))
def test_etype_tuple_to_str(): def test_etype_tuple_to_str():
"""Convert etype from tuple to string.""" """Convert etype from tuple to string."""
# Test for expected input. # Test for expected input.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment