Unverified Commit 08569139 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt][CUDA] Add `.pin_memory_()` to `FusedCSCSamplingGraph` (#6839)

parent e7f0c3a1
...@@ -948,30 +948,32 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -948,30 +948,32 @@ class FusedCSCSamplingGraph(SamplingGraph):
self._c_csc_graph.copy_to_shared_memory(shared_memory_name), self._c_csc_graph.copy_to_shared_memory(shared_memory_name),
) )
def _apply_to_members(self, fn):
"""Apply passed fn to all members of `FusedCSCSamplingGraph`."""
self.csc_indptr = recursive_apply(self.csc_indptr, fn)
self.indices = recursive_apply(self.indices, fn)
self.node_type_offset = recursive_apply(self.node_type_offset, fn)
self.type_per_edge = recursive_apply(self.type_per_edge, fn)
self.node_attributes = recursive_apply(self.node_attributes, fn)
self.edge_attributes = recursive_apply(self.edge_attributes, fn)
return self
def to(self, device: torch.device) -> None: # pylint: disable=invalid-name def to(self, device: torch.device) -> None: # pylint: disable=invalid-name
"""Copy `FusedCSCSamplingGraph` to the specified device.""" """Copy `FusedCSCSamplingGraph` to the specified device."""
def _to(x, device): def _to(x):
return x.to(device) if hasattr(x, "to") else x return x.to(device) if hasattr(x, "to") else x
self.csc_indptr = recursive_apply( return self._apply_to_members(_to)
self.csc_indptr, lambda x: _to(x, device)
)
self.indices = recursive_apply(self.indices, lambda x: _to(x, device))
self.node_type_offset = recursive_apply(
self.node_type_offset, lambda x: _to(x, device)
)
self.type_per_edge = recursive_apply(
self.type_per_edge, lambda x: _to(x, device)
)
self.node_attributes = recursive_apply(
self.node_attributes, lambda x: _to(x, device)
)
self.edge_attributes = recursive_apply(
self.edge_attributes, lambda x: _to(x, device)
)
return self def pin_memory_(self):
"""Copy `FusedCSCSamplingGraph` to the pinned memory in-place."""
def _pin(x):
return x.pinned_memory() if hasattr(x, "pinned_memory") else x
self._apply_to_members(_pin)
def fused_csc_sampling_graph( def fused_csc_sampling_graph(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment