Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
8e73c75f
Unverified
Commit
8e73c75f
authored
May 09, 2023
by
Xin Yao
Committed by
GitHub
May 09, 2023
Browse files
[Fix] Fix `tensor.storage()` deprecation warning (#5656)
parent
c4c9b830
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
25 deletions
+19
-25
python/dgl/backend/pytorch/tensor.py
python/dgl/backend/pytorch/tensor.py
+18
-12
python/dgl/dataloading/dataloader.py
python/dgl/dataloading/dataloader.py
+1
-13
No files found.
python/dgl/backend/pytorch/tensor.py
View file @
8e73c75f
...
...
@@ -12,8 +12,8 @@ from ... import ndarray as nd
from
...function.base
import
TargetCode
from
...utils
import
version
if
version
.
parse
(
th
.
__version__
)
<
version
.
parse
(
"1.
9
.0"
):
raise
RuntimeError
(
"DGL requires PyTorch >= 1.
9
.0"
)
if
version
.
parse
(
th
.
__version__
)
<
version
.
parse
(
"1.
12
.0"
):
raise
RuntimeError
(
"DGL requires PyTorch >= 1.
12
.0"
)
def
data_type_dict
():
...
...
@@ -428,17 +428,25 @@ def zerocopy_from_numpy(np_array):
return
th
.
as_tensor
(
np_array
)
if
version
.
parse
(
th
.
__version__
)
>=
version
.
parse
(
"1.10.0"
):
def
zerocopy_to_dgl_ndarray
(
data
):
if
data
.
dtype
==
th
.
bool
:
data
=
data
.
byte
()
return
nd
.
from_dlpack
(
dlpack
.
to_dlpack
(
data
.
contiguous
()))
def
zerocopy_to_dgl_ndarray
(
data
):
if
data
.
dtype
==
th
.
bool
:
data
=
data
.
byte
()
return
nd
.
from_dlpack
(
dlpack
.
to_dlpack
(
data
.
contiguous
()))
if
version
.
parse
(
th
.
__version__
)
>=
version
.
parse
(
"2.0.0"
):
def
check_is_view
(
input
):
assert
(
input
.
data_ptr
()
==
input
.
untyped_storage
().
data_ptr
()
),
"Cannot convert view tensors to dgl ndarray for write."
else
:
def
zerocopy_to_dgl_ndarray
(
data
):
return
nd
.
from_dlpack
(
dlpack
.
to_dlpack
(
data
.
contiguous
()))
def
check_is_view
(
input
):
assert
(
input
.
data_ptr
()
==
input
.
_storage
().
data_ptr
()
),
"Cannot convert view tensors to dgl ndarray for write."
def
zerocopy_to_dgl_ndarray_for_write
(
input
):
...
...
@@ -446,9 +454,7 @@ def zerocopy_to_dgl_ndarray_for_write(input):
"Cannot convert non-contiguous tensors "
"to dgl ndarray for write. Call .to_contiguous() first."
)
assert
input
.
numel
()
==
input
.
storage
().
size
(),
(
"Cannot convert view "
"tensors to dgl ndarray for write."
)
check_is_view
(
input
)
return
zerocopy_to_dgl_ndarray
(
input
)
...
...
python/dgl/dataloading/dataloader.py
View file @
8e73c75f
...
...
@@ -34,10 +34,8 @@ from ..utils import (
recursive_apply
,
recursive_apply_pair
,
set_num_threads
,
version
,
)
PYTORCH_VER
=
version
.
parse
(
torch
.
__version__
)
PYTHON_EXIT_STATUS
=
False
...
...
@@ -87,17 +85,7 @@ class _TensorizedDatasetIter(object):
# convert the type-ID pairs to dictionary
type_ids
=
batch
[:,
0
]
indices
=
batch
[:,
1
]
if
PYTORCH_VER
>=
version
.
parse
(
"1.10.0"
):
_
,
type_ids_sortidx
=
torch
.
sort
(
type_ids
,
stable
=
True
)
else
:
if
not
self
.
shuffle
:
dgl_warning
(
"The current output_nodes are out of order even if set shuffle "
"to False in Dataloader, the reason is that the current version "
"of torch dose not support stable sort. "
"Please update torch to 1.10.0 or higher to fix it."
)
type_ids_sortidx
=
torch
.
argsort
(
type_ids
)
_
,
type_ids_sortidx
=
torch
.
sort
(
type_ids
,
stable
=
True
)
type_ids
=
type_ids
[
type_ids_sortidx
]
indices
=
indices
[
type_ids_sortidx
]
type_id_uniq
,
type_id_count
=
torch
.
unique_consecutive
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment