Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
08569139
"docs/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "fdd3aabaef706d12280e87c6d68825573b7cfc7d"
Unverified
Commit
08569139
authored
Dec 28, 2023
by
Muhammed Fatih BALIN
Committed by
GitHub
Dec 28, 2023
Browse files
[GraphBolt][CUDA] Add `.pin_memory_()` to `FusedCSCSamplingGraph` (#6839)
parent
e7f0c3a1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
18 deletions
+20
-18
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+20
-18
No files found.
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
View file @
08569139
...
@@ -948,30 +948,32 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -948,30 +948,32 @@ class FusedCSCSamplingGraph(SamplingGraph):
self
.
_c_csc_graph
.
copy_to_shared_memory
(
shared_memory_name
),
self
.
_c_csc_graph
.
copy_to_shared_memory
(
shared_memory_name
),
)
)
def
_apply_to_members
(
self
,
fn
):
"""Apply passed fn to all members of `FusedCSCSamplingGraph`."""
self
.
csc_indptr
=
recursive_apply
(
self
.
csc_indptr
,
fn
)
self
.
indices
=
recursive_apply
(
self
.
indices
,
fn
)
self
.
node_type_offset
=
recursive_apply
(
self
.
node_type_offset
,
fn
)
self
.
type_per_edge
=
recursive_apply
(
self
.
type_per_edge
,
fn
)
self
.
node_attributes
=
recursive_apply
(
self
.
node_attributes
,
fn
)
self
.
edge_attributes
=
recursive_apply
(
self
.
edge_attributes
,
fn
)
return
self
def
to
(
self
,
device
:
torch
.
device
)
->
None
:
# pylint: disable=invalid-name
def
to
(
self
,
device
:
torch
.
device
)
->
None
:
# pylint: disable=invalid-name
"""Copy `FusedCSCSamplingGraph` to the specified device."""
"""Copy `FusedCSCSamplingGraph` to the specified device."""
def
_to
(
x
,
device
):
def
_to
(
x
):
return
x
.
to
(
device
)
if
hasattr
(
x
,
"to"
)
else
x
return
x
.
to
(
device
)
if
hasattr
(
x
,
"to"
)
else
x
self
.
csc_indptr
=
recursive_apply
(
return
self
.
_apply_to_members
(
_to
)
self
.
csc_indptr
,
lambda
x
:
_to
(
x
,
device
)
)
self
.
indices
=
recursive_apply
(
self
.
indices
,
lambda
x
:
_to
(
x
,
device
))
self
.
node_type_offset
=
recursive_apply
(
self
.
node_type_offset
,
lambda
x
:
_to
(
x
,
device
)
)
self
.
type_per_edge
=
recursive_apply
(
self
.
type_per_edge
,
lambda
x
:
_to
(
x
,
device
)
)
self
.
node_attributes
=
recursive_apply
(
self
.
node_attributes
,
lambda
x
:
_to
(
x
,
device
)
)
self
.
edge_attributes
=
recursive_apply
(
self
.
edge_attributes
,
lambda
x
:
_to
(
x
,
device
)
)
return
self
def
pin_memory_
(
self
):
"""Copy `FusedCSCSamplingGraph` to the pinned memory in-place."""
def
_pin
(
x
):
return
x
.
pinned_memory
()
if
hasattr
(
x
,
"pinned_memory"
)
else
x
self
.
_apply_to_members
(
_pin
)
def
fused_csc_sampling_graph
(
def
fused_csc_sampling_graph
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment