"...text-generation-inference.git" did not exist on "211b54ac41cae9a369f3d74bd6cc666ff4a0c526"
Unverified Commit cac25f63 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

fix compatibility with PyTorch 1.10 (#3454)

parent 9067565a
......@@ -27,7 +27,7 @@ def data_type_dict():
'int16' : th.int16,
'int32' : th.int32,
'int64' : th.int64,
'bool' : th.uint8}
'bool' : th.bool}
def cpu():
return th.device('cpu')
......@@ -330,8 +330,14 @@ def zerocopy_to_numpy(input):
def zerocopy_from_numpy(np_array):
return th.as_tensor(np_array)
def zerocopy_to_dgl_ndarray(data):
return nd.from_dlpack(dlpack.to_dlpack(data.contiguous()))
if LooseVersion(th.__version__) >= LooseVersion("1.10.0"):
def zerocopy_to_dgl_ndarray(data):
if data.dtype == th.bool:
data = data.byte()
return nd.from_dlpack(dlpack.to_dlpack(data.contiguous()))
else:
def zerocopy_to_dgl_ndarray(data):
return nd.from_dlpack(dlpack.to_dlpack(data.contiguous()))
def zerocopy_to_dgl_ndarray_for_write(input):
return zerocopy_to_dgl_ndarray(input)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment