Commit 57dea7f2 authored by hubertlu-tw's avatar hubertlu-tw
Browse files

Fix the cuda-specific transformer utils for ROCm

parent cb8b7a88
......@@ -20,7 +20,10 @@ except ImportError:
_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION = Version("470.42.01")
_driver_version = None
if torch.cuda.is_available():
_driver_version = parse(collect_env.get_nvidia_driver_version(collect_env.run))
if collect_env.get_nvidia_driver_version(collect_env.run) != None:
_driver_version = parse(collect_env.get_nvidia_driver_version(collect_env.run))
else:
_driver_version = None
HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER = _driver_version is not None and _driver_version >= _TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment