"tests/vscode:/vscode.git/clone" did not exist on "099b173f6f678d576a727ee4ad170599ec466f4e"
Unverified Commit 02443df1 authored by Ramon Zhou's avatar Ramon Zhou Committed by GitHub
Browse files

[GraphBolt] Add to function for SampledSubgraph (#6480)

parent aff6b685
......@@ -76,7 +76,9 @@ def etype_str_to_tuple(c_etype):
return ret
def _to(x, device):
def apply_to(x, device):
"""Apply `to` function to object x only if it has `to`."""
return x.to(device) if hasattr(x, "to") else x
......@@ -107,5 +109,5 @@ class CopyTo(IterDataPipe):
def __iter__(self):
for data in self.datapipe:
data = recursive_apply(data, _to, self.device)
data = recursive_apply(data, apply_to, self.device)
yield data
......@@ -4,7 +4,9 @@ from typing import Dict, Tuple, Union
import torch
from .base import etype_str_to_tuple, isin
from dgl.utils import recursive_apply
from .base import apply_to, etype_str_to_tuple, isin
__all__ = ["SampledSubgraph"]
......@@ -189,6 +191,22 @@ class SampledSubgraph:
)
return calling_class(*_slice_subgraph(self, index))
def to(self, device: torch.device) -> None: # pylint: disable=invalid-name
"""Copy `SampledSubgraph` to the specified device using reflection."""
for attr in dir(self):
# Only copy member variables.
if not callable(getattr(self, attr)) and not attr.startswith("__"):
setattr(
self,
attr,
recursive_apply(
getattr(self, attr), lambda x: apply_to(x, device)
),
)
return self
def _to_reverse_ids(node_pair, original_row_node_ids, original_column_node_ids):
u, v = node_pair
......
import unittest
import backend as F
import pytest
import torch
......@@ -132,3 +135,53 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
)
_assert_container_equal(result.original_row_node_ids, expected_row_node_ids)
_assert_container_equal(result.original_edge_ids, expected_edge_ids)
@unittest.skipIf(
F._default_context_str == "cpu",
reason="`to` function needs GPU to test.",
)
def test_sampled_subgraph_to_device():
# Initialize data.
node_pairs = {
"A:relation:B": (
torch.tensor([0, 1, 2]),
torch.tensor([2, 1, 0]),
)
}
original_row_node_ids = {
"A": torch.tensor([13, 14, 15]),
}
src_to_exclude = torch.tensor([15, 13])
original_column_node_ids = {
"B": torch.tensor([10, 11, 12]),
}
dst_to_exclude = torch.tensor([10, 12])
original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
subgraph = SampledSubgraphImpl(
node_pairs=node_pairs,
original_column_node_ids=original_column_node_ids,
original_row_node_ids=original_row_node_ids,
original_edge_ids=original_edge_ids,
)
edges_to_exclude = {
"A:relation:B": (
src_to_exclude,
dst_to_exclude,
)
}
graph = subgraph.exclude_edges(edges_to_exclude)
# Copy to device.
graph = graph.to("cuda")
# Check.
for key in graph.node_pairs:
assert graph.node_pairs[key][0].device.type == "cuda"
assert graph.node_pairs[key][1].device.type == "cuda"
for key in graph.original_column_node_ids:
assert graph.original_column_node_ids[key].device.type == "cuda"
for key in graph.original_row_node_ids:
assert graph.original_row_node_ids[key].device.type == "cuda"
for key in graph.original_edge_ids:
assert graph.original_edge_ids[key].device.type == "cuda"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment