Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
57dea7f2
Commit
57dea7f2
authored
Aug 08, 2022
by
hubertlu-tw
Browse files
Fix the cuda-specific transformer utils for ROCm
parent
cb8b7a88
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
1 deletion
+4
-1
apex/transformer/testing/distributed_test_base.py
apex/transformer/testing/distributed_test_base.py
+4
-1
No files found.
apex/transformer/testing/distributed_test_base.py
View file @
57dea7f2
...
...
@@ -20,7 +20,10 @@ except ImportError:
_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION
=
Version
(
"470.42.01"
)
_driver_version
=
None
if
torch
.
cuda
.
is_available
():
_driver_version
=
parse
(
collect_env
.
get_nvidia_driver_version
(
collect_env
.
run
))
if
collect_env
.
get_nvidia_driver_version
(
collect_env
.
run
)
!=
None
:
_driver_version
=
parse
(
collect_env
.
get_nvidia_driver_version
(
collect_env
.
run
))
else
:
_driver_version
=
None
HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER
=
_driver_version
is
not
None
and
_driver_version
>=
_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment