Unverified Commit 5b51e968 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt][CUDA] Fix graph pinning and add tests (#6864)

parent 22a2513d
......@@ -971,7 +971,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
"""Copy `FusedCSCSamplingGraph` to the pinned memory in-place."""
def _pin(x):
return x.pinned_memory() if hasattr(x, "pinned_memory") else x
return x.pin_memory() if hasattr(x, "pin_memory") else x
self._apply_to_members(_pin)
......
......@@ -1511,11 +1511,7 @@ def test_from_dglgraph_heterogeneous():
}
@unittest.skipIf(
F._default_context_str == "cpu",
reason="`to` function needs GPU to test.",
)
def test_csc_sampling_graph_to_device():
def create_fused_csc_sampling_graph():
# Initialize data.
total_num_nodes = 10
total_num_edges = 9
......@@ -1541,7 +1537,7 @@ def test_csc_sampling_graph_to_device():
}
# Construct FusedCSCSamplingGraph.
graph = gb.fused_csc_sampling_graph(
return gb.fused_csc_sampling_graph(
indptr,
indices,
edge_attributes=edge_attributes,
......@@ -1551,6 +1547,15 @@ def test_csc_sampling_graph_to_device():
edge_type_to_id=etypes,
)
@unittest.skipIf(
F._default_context_str == "cpu",
reason="`to` function needs GPU to test.",
)
def test_csc_sampling_graph_to_device():
# Construct FusedCSCSamplingGraph.
graph = create_fused_csc_sampling_graph()
# Copy to device.
graph = graph.to("cuda")
......@@ -1564,6 +1569,27 @@ def test_csc_sampling_graph_to_device():
assert graph.edge_attributes[key].device.type == "cuda"
@unittest.skipIf(
F._default_context_str == "cpu",
reason="Tests for pinned memory are only meaningful on GPU.",
)
def test_csc_sampling_graph_to_pinned_memory():
# Construct FusedCSCSamplingGraph.
graph = create_fused_csc_sampling_graph()
# Copy to pinned_memory in-place.
graph.pin_memory_()
# Check.
assert graph.csc_indptr.is_pinned()
assert graph.indices.is_pinned()
assert graph.node_type_offset.is_pinned()
assert graph.type_per_edge.is_pinned()
assert graph.csc_indptr.is_pinned()
for key in graph.edge_attributes:
assert graph.edge_attributes[key].is_pinned()
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment