Unverified Commit bf3c008e authored by eqy's avatar eqy Committed by GitHub
Browse files

[UCC][TORCH_UCC]Do integer driver version comparison for UCC (#1411)

* Integer driver number comparison

* packaging
parent 3ff1a10f
import os import os
import sys import sys
import unittest import unittest
from packaging.version import Version, parse
import torch import torch
from torch import distributed as dist from torch import distributed as dist
...@@ -16,10 +17,10 @@ except ImportError: ...@@ -16,10 +17,10 @@ except ImportError:
HAS_TORCH_UCC = False HAS_TORCH_UCC = False
# NOTE(mkozuki): Version guard for ucc. ref: https://github.com/openucx/ucc/issues/496 # NOTE(mkozuki): Version guard for ucc. ref: https://github.com/openucx/ucc/issues/496
_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION = "470.42.01" _TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION = Version("470.42.01")
_driver_version = None _driver_version = None
if torch.cuda.is_available(): if torch.cuda.is_available():
_driver_version = collect_env.get_nvidia_driver_version(collect_env.run) _driver_version = parse(collect_env.get_nvidia_driver_version(collect_env.run))
HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER = _driver_version is not None and _driver_version >= _TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER = _driver_version is not None and _driver_version >= _TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION
......
...@@ -3,3 +3,4 @@ tqdm>=4.28.1 ...@@ -3,3 +3,4 @@ tqdm>=4.28.1
numpy>=1.15.3 numpy>=1.15.3
PyYAML>=5.1 PyYAML>=5.1
pytest>=3.5.1 pytest>=3.5.1
packaging>=14.0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment