Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
f051a5cc
Unverified
Commit
f051a5cc
authored
Jan 31, 2024
by
Muhammed Fatih BALIN
Committed by
GitHub
Jan 31, 2024
Browse files
[GraphBolt][CUDA] Add `is_pinned` method to `TorchBasedFeatureStore`. (#7041)
parent
e7c3454f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
6 deletions
+11
-6
python/dgl/graphbolt/impl/torch_based_feature_store.py
python/dgl/graphbolt/impl/torch_based_feature_store.py
+8
-0
tests/python/pytorch/graphbolt/impl/test_torch_based_feature_store.py
.../pytorch/graphbolt/impl/test_torch_based_feature_store.py
+3
-6
No files found.
python/dgl/graphbolt/impl/torch_based_feature_store.py
View file @
f051a5cc
...
@@ -198,6 +198,10 @@ class TorchBasedFeature(Feature):
...
@@ -198,6 +198,10 @@ class TorchBasedFeature(Feature):
self
.
_is_inplace_pinned
.
add
(
x
)
self
.
_is_inplace_pinned
.
add
(
x
)
def
is_pinned
(
self
):
"""Returns True if the stored feature is pinned."""
return
self
.
_tensor
.
is_pinned
()
def
to
(
self
,
device
):
# pylint: disable=invalid-name
def
to
(
self
,
device
):
# pylint: disable=invalid-name
"""Copy `TorchBasedFeature` to the specified device."""
"""Copy `TorchBasedFeature` to the specified device."""
# copy.copy is a shallow copy so it does not copy tensor memory.
# copy.copy is a shallow copy so it does not copy tensor memory.
...
@@ -293,6 +297,10 @@ class TorchBasedFeatureStore(BasicFeatureStore):
...
@@ -293,6 +297,10 @@ class TorchBasedFeatureStore(BasicFeatureStore):
for
feature
in
self
.
_features
.
values
():
for
feature
in
self
.
_features
.
values
():
feature
.
pin_memory_
()
feature
.
pin_memory_
()
def
is_pinned
(
self
):
"""Returns True if all the stored features are pinned."""
return
all
(
feature
.
is_pinned
()
for
feature
in
self
.
_features
.
values
())
def
to
(
self
,
device
):
# pylint: disable=invalid-name
def
to
(
self
,
device
):
# pylint: disable=invalid-name
"""Copy `TorchBasedFeatureStore` to the specified device."""
"""Copy `TorchBasedFeatureStore` to the specified device."""
# copy.copy is a shallow copy so it does not copy tensor memory.
# copy.copy is a shallow copy so it does not copy tensor memory.
...
...
tests/python/pytorch/graphbolt/impl/test_torch_based_feature_store.py
View file @
f051a5cc
...
@@ -136,11 +136,6 @@ def test_torch_based_feature(in_memory):
...
@@ -136,11 +136,6 @@ def test_torch_based_feature(in_memory):
feature_a
=
feature_b
=
None
feature_a
=
feature_b
=
None
def
is_feature_store_pinned
(
store
):
for
feature
in
store
.
_features
.
values
():
assert
feature
.
_tensor
.
is_pinned
()
def
is_feature_store_on_cuda
(
store
):
def
is_feature_store_on_cuda
(
store
):
for
feature
in
store
.
_features
.
values
():
for
feature
in
store
.
_features
.
values
():
assert
feature
.
_tensor
.
is_cuda
assert
feature
.
_tensor
.
is_cuda
...
@@ -181,7 +176,7 @@ def test_feature_store_to_device(device):
...
@@ -181,7 +176,7 @@ def test_feature_store_to_device(device):
feature_store
=
gb
.
TorchBasedFeatureStore
(
feature_data
)
feature_store
=
gb
.
TorchBasedFeatureStore
(
feature_data
)
feature_store2
=
feature_store
.
to
(
device
)
feature_store2
=
feature_store
.
to
(
device
)
if
device
==
"pinned"
:
if
device
==
"pinned"
:
is_
feature_store_pinned
(
feature_store2
)
assert
feature_store
2
.
is
_pinned
()
elif
device
==
"cuda"
:
elif
device
==
"cuda"
:
is_feature_store_on_cuda
(
feature_store2
)
is_feature_store_on_cuda
(
feature_store2
)
...
@@ -228,6 +223,8 @@ def test_torch_based_pinned_feature(dtype, idtype, shape, in_place):
...
@@ -228,6 +223,8 @@ def test_torch_based_pinned_feature(dtype, idtype, shape, in_place):
else
:
else
:
feature
=
feature
.
to
(
"pinned"
)
feature
=
feature
.
to
(
"pinned"
)
assert
feature
.
is_pinned
()
# Test read entire pinned feature, the result should be on cuda.
# Test read entire pinned feature, the result should be on cuda.
assert
torch
.
equal
(
feature
.
read
(),
test_tensor_cuda
)
assert
torch
.
equal
(
feature
.
read
(),
test_tensor_cuda
)
assert
feature
.
read
().
is_cuda
assert
feature
.
read
().
is_cuda
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment