Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
a4a744f3
Commit
a4a744f3
authored
May 16, 2025
by
limm
Browse files
block the _driver_version parameter
parent
82d3aa12
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
8 deletions
+8
-8
apex/transformer/testing/distributed_test_base.py
apex/transformer/testing/distributed_test_base.py
+8
-8
No files found.
apex/transformer/testing/distributed_test_base.py
View file @
a4a744f3
...
...
@@ -14,9 +14,9 @@ from apex.transformer._ucc_util import HAS_UCC
# NOTE(mkozuki): Version guard for ucc. ref: https://github.com/openucx/ucc/issues/496
_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION
=
Version
(
"470.42.01"
)
_driver_version
=
None
if
torch
.
cuda
.
is_available
():
_driver_version
=
parse
(
collect_env
.
get_nvidia_driver_version
(
collect_env
.
run
))
HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER
=
_driver_version
is
not
None
and
_driver_version
>=
_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION
#
if torch.cuda.is_available():
#
_driver_version = parse(collect_env.get_nvidia_driver_version(collect_env.run))
#
HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER = _driver_version is not None and _driver_version >= _TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION
class
DistributedTestBase
(
common_distributed
.
MultiProcessTestCase
):
...
...
@@ -85,11 +85,11 @@ class NcclDistributedTestBase(DistributedTestBase):
HAS_UCC
,
"Requires either torch ucc or pytorch build from source with native ucc installed and enabled"
,
)
@
unittest
.
skipUnless
(
HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER
,
f
"`torch_ucc` requires NVIDIA driver >=
{
_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION
}
but
{
_driver_version
}
found. "
"See https://github.com/openucx/ucc/issues/496"
,
)
#
@unittest.skipUnless(
#
HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER,
#
f"`torch_ucc` requires NVIDIA driver >= {_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION} but {_driver_version} found. "
#
"See https://github.com/openucx/ucc/issues/496",
#
)
class
UccDistributedTestBase
(
DistributedTestBase
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment