Unverified Commit bf3c008e authored by eqy's avatar eqy Committed by GitHub
Browse files

[UCC][TORCH_UCC]Do integer driver version comparison for UCC (#1411)

* Integer driver number comparison

* packaging
parent 3ff1a10f
import os
import sys
import unittest
from packaging.version import Version, parse
import torch
from torch import distributed as dist
......@@ -16,10 +17,10 @@ except ImportError:
HAS_TORCH_UCC = False
# NOTE(mkozuki): Version guard for ucc. ref: https://github.com/openucx/ucc/issues/496
_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION = "470.42.01"
_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION = Version("470.42.01")
_driver_version = None
if torch.cuda.is_available():
_driver_version = collect_env.get_nvidia_driver_version(collect_env.run)
_driver_version = parse(collect_env.get_nvidia_driver_version(collect_env.run))
HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER = _driver_version is not None and _driver_version >= _TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION
......
......@@ -3,3 +3,4 @@ tqdm>=4.28.1
numpy>=1.15.3
PyYAML>=5.1
pytest>=3.5.1
packaging>=14.0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment