"git@developer.sourcefind.cn:OpenDAS/dlib.git" did not exist on "fca97a63fd6b752f601404b592021f217265f8d9"
Unverified Commit 5b51e968 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt][CUDA] Fix graph pinning and add tests (#6864)

parent 22a2513d
...@@ -971,7 +971,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -971,7 +971,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
"""Copy `FusedCSCSamplingGraph` to the pinned memory in-place.""" """Copy `FusedCSCSamplingGraph` to the pinned memory in-place."""
def _pin(x): def _pin(x):
return x.pinned_memory() if hasattr(x, "pinned_memory") else x return x.pin_memory() if hasattr(x, "pin_memory") else x
self._apply_to_members(_pin) self._apply_to_members(_pin)
......
...@@ -1511,11 +1511,7 @@ def test_from_dglgraph_heterogeneous(): ...@@ -1511,11 +1511,7 @@ def test_from_dglgraph_heterogeneous():
} }
@unittest.skipIf( def create_fused_csc_sampling_graph():
F._default_context_str == "cpu",
reason="`to` function needs GPU to test.",
)
def test_csc_sampling_graph_to_device():
# Initialize data. # Initialize data.
total_num_nodes = 10 total_num_nodes = 10
total_num_edges = 9 total_num_edges = 9
...@@ -1541,7 +1537,7 @@ def test_csc_sampling_graph_to_device(): ...@@ -1541,7 +1537,7 @@ def test_csc_sampling_graph_to_device():
} }
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.fused_csc_sampling_graph( return gb.fused_csc_sampling_graph(
indptr, indptr,
indices, indices,
edge_attributes=edge_attributes, edge_attributes=edge_attributes,
...@@ -1551,6 +1547,15 @@ def test_csc_sampling_graph_to_device(): ...@@ -1551,6 +1547,15 @@ def test_csc_sampling_graph_to_device():
edge_type_to_id=etypes, edge_type_to_id=etypes,
) )
@unittest.skipIf(
F._default_context_str == "cpu",
reason="`to` function needs GPU to test.",
)
def test_csc_sampling_graph_to_device():
# Construct FusedCSCSamplingGraph.
graph = create_fused_csc_sampling_graph()
# Copy to device. # Copy to device.
graph = graph.to("cuda") graph = graph.to("cuda")
...@@ -1564,6 +1569,27 @@ def test_csc_sampling_graph_to_device(): ...@@ -1564,6 +1569,27 @@ def test_csc_sampling_graph_to_device():
assert graph.edge_attributes[key].device.type == "cuda" assert graph.edge_attributes[key].device.type == "cuda"
@unittest.skipIf(
F._default_context_str == "cpu",
reason="Tests for pinned memory are only meaningful on GPU.",
)
def test_csc_sampling_graph_to_pinned_memory():
# Construct FusedCSCSamplingGraph.
graph = create_fused_csc_sampling_graph()
# Copy to pinned_memory in-place.
graph.pin_memory_()
# Check.
assert graph.csc_indptr.is_pinned()
assert graph.indices.is_pinned()
assert graph.node_type_offset.is_pinned()
assert graph.type_per_edge.is_pinned()
assert graph.csc_indptr.is_pinned()
for key in graph.edge_attributes:
assert graph.edge_attributes[key].is_pinned()
@unittest.skipIf( @unittest.skipIf(
F._default_context_str == "gpu", F._default_context_str == "gpu",
reason="Graph is CPU only at present.", reason="Graph is CPU only at present.",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment