Commit a4a744f3 authored by limm's avatar limm
Browse files

block the _driver_version parameter

parent 82d3aa12
......@@ -14,9 +14,9 @@ from apex.transformer._ucc_util import HAS_UCC
# NOTE(mkozuki): Version guard for ucc. ref: https://github.com/openucx/ucc/issues/496
_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION = Version("470.42.01")
_driver_version = None
if torch.cuda.is_available():
_driver_version = parse(collect_env.get_nvidia_driver_version(collect_env.run))
HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER = _driver_version is not None and _driver_version >= _TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION
#if torch.cuda.is_available():
# _driver_version = parse(collect_env.get_nvidia_driver_version(collect_env.run))
#HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER = _driver_version is not None and _driver_version >= _TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION
class DistributedTestBase(common_distributed.MultiProcessTestCase):
......@@ -85,11 +85,11 @@ class NcclDistributedTestBase(DistributedTestBase):
HAS_UCC,
"Requires either torch ucc or pytorch build from source with native ucc installed and enabled",
)
@unittest.skipUnless(
HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER,
f"`torch_ucc` requires NVIDIA driver >= {_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION} but {_driver_version} found. "
"See https://github.com/openucx/ucc/issues/496",
)
#@unittest.skipUnless(
# HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER,
# f"`torch_ucc` requires NVIDIA driver >= {_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION} but {_driver_version} found. "
# "See https://github.com/openucx/ucc/issues/496",
#)
class UccDistributedTestBase(DistributedTestBase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment