"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "a7361dccdc581147620bbd74a6d295cd92daf616"
Unverified Commit b81efb2b authored by nv-dlasalle's avatar nv-dlasalle Committed by GitHub
Browse files

[PyTorch][Bugfix] Use uint8 instead of bool in pytorch to be compatible with...


[PyTorch][Bugfix] Use uint8 instead of bool in pytorch to be compatible with nightly version (#3406)

* Use uint8 instead of bool in pytorch

* Handle type aliases

* Fix syntax error
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
parent a47ab71d
...@@ -60,9 +60,16 @@ def load_backend(mod_name): ...@@ -60,9 +60,16 @@ def load_backend(mod_name):
# override data type dict function # override data type dict function
setattr(thismod, 'data_type_dict', data_type_dict) setattr(thismod, 'data_type_dict', data_type_dict)
# for data types with aliases, treat the first listed type as
# the true one
rev_data_type_dict = {}
for k, v in data_type_dict.items():
if not v in rev_data_type_dict.keys():
rev_data_type_dict[v] = k
setattr(thismod, setattr(thismod,
'reverse_data_type_dict', 'reverse_data_type_dict',
{v: k for k, v in data_type_dict.items()}) rev_data_type_dict)
# log backend name # log backend name
setattr(thismod, 'backend_name', mod_name) setattr(thismod, 'backend_name', mod_name)
else: else:
......
...@@ -27,7 +27,7 @@ def data_type_dict(): ...@@ -27,7 +27,7 @@ def data_type_dict():
'int16' : th.int16, 'int16' : th.int16,
'int32' : th.int32, 'int32' : th.int32,
'int64' : th.int64, 'int64' : th.int64,
'bool' : th.bool} 'bool' : th.uint8}
def cpu(): def cpu():
return th.device('cpu') return th.device('cpu')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment