Unverified Commit cac25f63 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

fix compatibility with PyTorch 1.10 (#3454)

parent 9067565a
...@@ -27,7 +27,7 @@ def data_type_dict(): ...@@ -27,7 +27,7 @@ def data_type_dict():
'int16' : th.int16, 'int16' : th.int16,
'int32' : th.int32, 'int32' : th.int32,
'int64' : th.int64, 'int64' : th.int64,
'bool' : th.uint8} 'bool' : th.bool}
def cpu(): def cpu():
return th.device('cpu') return th.device('cpu')
...@@ -330,7 +330,13 @@ def zerocopy_to_numpy(input): ...@@ -330,7 +330,13 @@ def zerocopy_to_numpy(input):
def zerocopy_from_numpy(np_array): def zerocopy_from_numpy(np_array):
return th.as_tensor(np_array) return th.as_tensor(np_array)
def zerocopy_to_dgl_ndarray(data): if LooseVersion(th.__version__) >= LooseVersion("1.10.0"):
def zerocopy_to_dgl_ndarray(data):
if data.dtype == th.bool:
data = data.byte()
return nd.from_dlpack(dlpack.to_dlpack(data.contiguous()))
else:
def zerocopy_to_dgl_ndarray(data):
return nd.from_dlpack(dlpack.to_dlpack(data.contiguous())) return nd.from_dlpack(dlpack.to_dlpack(data.contiguous()))
def zerocopy_to_dgl_ndarray_for_write(input): def zerocopy_to_dgl_ndarray_for_write(input):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment