Unverified Commit 8e73c75f authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[Fix] Fix `tensor.storage()` deprecation warning (#5656)

parent c4c9b830
...@@ -12,8 +12,8 @@ from ... import ndarray as nd ...@@ -12,8 +12,8 @@ from ... import ndarray as nd
from ...function.base import TargetCode from ...function.base import TargetCode
from ...utils import version from ...utils import version
if version.parse(th.__version__) < version.parse("1.9.0"): if version.parse(th.__version__) < version.parse("1.12.0"):
raise RuntimeError("DGL requires PyTorch >= 1.9.0") raise RuntimeError("DGL requires PyTorch >= 1.12.0")
def data_type_dict(): def data_type_dict():
...@@ -428,17 +428,25 @@ def zerocopy_from_numpy(np_array): ...@@ -428,17 +428,25 @@ def zerocopy_from_numpy(np_array):
return th.as_tensor(np_array) return th.as_tensor(np_array)
if version.parse(th.__version__) >= version.parse("1.10.0"): def zerocopy_to_dgl_ndarray(data):
def zerocopy_to_dgl_ndarray(data):
if data.dtype == th.bool: if data.dtype == th.bool:
data = data.byte() data = data.byte()
return nd.from_dlpack(dlpack.to_dlpack(data.contiguous())) return nd.from_dlpack(dlpack.to_dlpack(data.contiguous()))
if version.parse(th.__version__) >= version.parse("2.0.0"):
def check_is_view(input):
assert (
input.data_ptr() == input.untyped_storage().data_ptr()
), "Cannot convert view tensors to dgl ndarray for write."
else: else:
def zerocopy_to_dgl_ndarray(data): def check_is_view(input):
return nd.from_dlpack(dlpack.to_dlpack(data.contiguous())) assert (
input.data_ptr() == input._storage().data_ptr()
), "Cannot convert view tensors to dgl ndarray for write."
def zerocopy_to_dgl_ndarray_for_write(input): def zerocopy_to_dgl_ndarray_for_write(input):
...@@ -446,9 +454,7 @@ def zerocopy_to_dgl_ndarray_for_write(input): ...@@ -446,9 +454,7 @@ def zerocopy_to_dgl_ndarray_for_write(input):
"Cannot convert non-contiguous tensors " "Cannot convert non-contiguous tensors "
"to dgl ndarray for write. Call .to_contiguous() first." "to dgl ndarray for write. Call .to_contiguous() first."
) )
assert input.numel() == input.storage().size(), ( check_is_view(input)
"Cannot convert view " "tensors to dgl ndarray for write."
)
return zerocopy_to_dgl_ndarray(input) return zerocopy_to_dgl_ndarray(input)
......
...@@ -34,10 +34,8 @@ from ..utils import ( ...@@ -34,10 +34,8 @@ from ..utils import (
recursive_apply, recursive_apply,
recursive_apply_pair, recursive_apply_pair,
set_num_threads, set_num_threads,
version,
) )
PYTORCH_VER = version.parse(torch.__version__)
PYTHON_EXIT_STATUS = False PYTHON_EXIT_STATUS = False
...@@ -87,17 +85,7 @@ class _TensorizedDatasetIter(object): ...@@ -87,17 +85,7 @@ class _TensorizedDatasetIter(object):
# convert the type-ID pairs to dictionary # convert the type-ID pairs to dictionary
type_ids = batch[:, 0] type_ids = batch[:, 0]
indices = batch[:, 1] indices = batch[:, 1]
if PYTORCH_VER >= version.parse("1.10.0"):
_, type_ids_sortidx = torch.sort(type_ids, stable=True) _, type_ids_sortidx = torch.sort(type_ids, stable=True)
else:
if not self.shuffle:
dgl_warning(
"The current output_nodes are out of order even if set shuffle "
"to False in Dataloader, the reason is that the current version "
"of torch dose not support stable sort. "
"Please update torch to 1.10.0 or higher to fix it."
)
type_ids_sortidx = torch.argsort(type_ids)
type_ids = type_ids[type_ids_sortidx] type_ids = type_ids[type_ids_sortidx]
indices = indices[type_ids_sortidx] indices = indices[type_ids_sortidx]
type_id_uniq, type_id_count = torch.unique_consecutive( type_id_uniq, type_id_count = torch.unique_consecutive(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment