Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
4015c5fe
"docs/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "2e7a28652ab7d15ebf0d0180e33fb0db3406c744"
Unverified
Commit
4015c5fe
authored
Jul 03, 2023
by
Quan (Andy) Gan
Committed by
GitHub
Jul 03, 2023
Browse files
[GraphBolt] Datapipe for copying to given device (#5918)
parent
2668d62f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
54 additions
and
0 deletions
+54
-0
python/dgl/graphbolt/__init__.py
python/dgl/graphbolt/__init__.py
+1
-0
python/dgl/graphbolt/copy_to.py
python/dgl/graphbolt/copy_to.py
+39
-0
tests/python/pytorch/graphbolt/test_copy_to.py
tests/python/pytorch/graphbolt/test_copy_to.py
+14
-0
No files found.
python/dgl/graphbolt/__init__.py
View file @
4015c5fe
...
...
@@ -9,6 +9,7 @@ from .graph_storage import *
from
.itemset
import
*
from
.minibatch_sampler
import
*
from
.feature_store
import
*
from
.copy_to
import
*
from
.dataset
import
*
from
.subgraph_sampler
import
*
...
...
python/dgl/graphbolt/copy_to.py
0 → 100644
View file @
4015c5fe
"""Graph Bolt CUDA-related Data Pipelines"""
from
torchdata.datapipes.iter
import
IterDataPipe
from
..utils
import
recursive_apply
def
_to
(
x
,
device
):
return
x
.
to
(
device
)
if
hasattr
(
x
,
"to"
)
else
x
class
CopyTo
(
IterDataPipe
):
"""DataPipe that transfers each element yielded from the previous DataPipe
to the given device.
This is equivalent to
.. code:: python
for data in datapipe:
yield data.to(device)
Parameters
----------
datapipe : DataPipe
The DataPipe.
device : torch.device
The PyTorch CUDA device.
"""
def
__init__
(
self
,
datapipe
,
device
):
super
().
__init__
()
self
.
datapipe
=
datapipe
self
.
device
=
device
def
__iter__
(
self
):
for
data
in
self
.
datapipe
:
data
=
recursive_apply
(
data
,
_to
,
self
.
device
)
yield
data
tests/python/pytorch/graphbolt/test_copy_to.py
0 → 100644
View file @
4015c5fe
import
unittest
import
backend
as
F
import
dgl.graphbolt
import
torch
@
unittest
.
skipIf
(
F
.
_default_context_str
==
"cpu"
,
"CopyTo needs GPU to test"
)
def
test_CopyTo
():
dp
=
dgl
.
graphbolt
.
MinibatchSampler
(
torch
.
randn
(
20
),
4
)
dp
=
dgl
.
graphbolt
.
CopyTo
(
dp
,
"cuda"
)
for
data
in
dp
:
assert
data
.
device
.
type
==
"cuda"
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment