Unverified Commit 8e73c75f authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[Fix] Fix `tensor.storage()` deprecation warning (#5656)

parent c4c9b830
......@@ -12,8 +12,8 @@ from ... import ndarray as nd
from ...function.base import TargetCode
from ...utils import version
if version.parse(th.__version__) < version.parse("1.9.0"):
raise RuntimeError("DGL requires PyTorch >= 1.9.0")
if version.parse(th.__version__) < version.parse("1.12.0"):
raise RuntimeError("DGL requires PyTorch >= 1.12.0")
def data_type_dict():
......@@ -428,17 +428,25 @@ def zerocopy_from_numpy(np_array):
return th.as_tensor(np_array)
if version.parse(th.__version__) >= version.parse("1.10.0"):
def zerocopy_to_dgl_ndarray(data):
if data.dtype == th.bool:
data = data.byte()
return nd.from_dlpack(dlpack.to_dlpack(data.contiguous()))
def zerocopy_to_dgl_ndarray(data):
if data.dtype == th.bool:
data = data.byte()
return nd.from_dlpack(dlpack.to_dlpack(data.contiguous()))
if version.parse(th.__version__) >= version.parse("2.0.0"):
def check_is_view(input):
assert (
input.data_ptr() == input.untyped_storage().data_ptr()
), "Cannot convert view tensors to dgl ndarray for write."
else:
def zerocopy_to_dgl_ndarray(data):
return nd.from_dlpack(dlpack.to_dlpack(data.contiguous()))
def check_is_view(input):
assert (
input.data_ptr() == input._storage().data_ptr()
), "Cannot convert view tensors to dgl ndarray for write."
def zerocopy_to_dgl_ndarray_for_write(input):
......@@ -446,9 +454,7 @@ def zerocopy_to_dgl_ndarray_for_write(input):
"Cannot convert non-contiguous tensors "
"to dgl ndarray for write. Call .to_contiguous() first."
)
assert input.numel() == input.storage().size(), (
"Cannot convert view " "tensors to dgl ndarray for write."
)
check_is_view(input)
return zerocopy_to_dgl_ndarray(input)
......
......@@ -34,10 +34,8 @@ from ..utils import (
recursive_apply,
recursive_apply_pair,
set_num_threads,
version,
)
PYTORCH_VER = version.parse(torch.__version__)
PYTHON_EXIT_STATUS = False
......@@ -87,17 +85,7 @@ class _TensorizedDatasetIter(object):
# convert the type-ID pairs to dictionary
type_ids = batch[:, 0]
indices = batch[:, 1]
if PYTORCH_VER >= version.parse("1.10.0"):
_, type_ids_sortidx = torch.sort(type_ids, stable=True)
else:
if not self.shuffle:
dgl_warning(
"The current output_nodes are out of order even if set shuffle "
"to False in Dataloader, the reason is that the current version "
"of torch dose not support stable sort. "
"Please update torch to 1.10.0 or higher to fix it."
)
type_ids_sortidx = torch.argsort(type_ids)
_, type_ids_sortidx = torch.sort(type_ids, stable=True)
type_ids = type_ids[type_ids_sortidx]
indices = indices[type_ids_sortidx]
type_id_uniq, type_id_count = torch.unique_consecutive(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment