Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
e50be9a0
Unverified
Commit
e50be9a0
authored
Apr 11, 2024
by
Zach Mueller
Committed by
GitHub
Apr 11, 2024
Browse files
Guard XLA version imports (#30167)
parent
fbdb978e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
13 additions
and
2 deletions
+13
-2
src/transformers/trainer.py
src/transformers/trainer.py
+11
-2
src/transformers/utils/__init__.py
src/transformers/utils/__init__.py
+1
-0
src/transformers/utils/import_utils.py
src/transformers/utils/import_utils.py
+1
-0
No files found.
src/transformers/trainer.py
View file @
e50be9a0
...
...
@@ -136,6 +136,7 @@ from .utils import (
SAFE_WEIGHTS_NAME
,
WEIGHTS_INDEX_NAME
,
WEIGHTS_NAME
,
XLA_FSDPV2_MIN_VERSION
,
PushInProgress
,
PushToHubMixin
,
can_return_loss
,
...
...
@@ -179,8 +180,14 @@ if is_datasets_available():
if
is_torch_xla_available
():
import
torch_xla.core.xla_model
as
xm
import
torch_xla.debug.metrics
as
met
import
torch_xla.distributed.spmd
as
xs
import
torch_xla.runtime
as
xr
from
torch_xla
import
__version__
as
XLA_VERSION
IS_XLA_FSDPV2_POST_2_2
=
version
.
parse
(
XLA_VERSION
)
>=
version
.
parse
(
XLA_FSDPV2_MIN_VERSION
)
if
IS_XLA_FSDPV2_POST_2_2
:
import
torch_xla.distributed.spmd
as
xs
import
torch_xla.runtime
as
xr
else
:
IS_XLA_FSDPV2_POST_2_2
=
False
if
is_sagemaker_mp_enabled
():
...
...
@@ -664,6 +671,8 @@ class Trainer:
self
.
is_fsdp_xla_v2_enabled
=
args
.
fsdp_config
.
get
(
"xla_fsdp_v2"
,
False
)
if
self
.
is_fsdp_xla_v2_enabled
:
if
not
IS_XLA_FSDPV2_POST_2_2
:
raise
ValueError
(
"FSDPv2 requires `torch_xla` 2.2 or higher."
)
# Prepare the SPMD mesh that is going to be used by the data loader and the FSDPv2 wrapper.
# Tensor axis is just a placeholder where it will not be used in FSDPv2.
num_devices
=
xr
.
global_runtime_device_count
()
...
...
src/transformers/utils/__init__.py
View file @
e50be9a0
...
...
@@ -98,6 +98,7 @@ from .import_utils import (
USE_JAX
,
USE_TF
,
USE_TORCH
,
XLA_FSDPV2_MIN_VERSION
,
DummyObject
,
OptionalDependencyNotAvailable
,
_LazyModule
,
...
...
src/transformers/utils/import_utils.py
View file @
e50be9a0
...
...
@@ -89,6 +89,7 @@ TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
ACCELERATE_MIN_VERSION
=
"0.21.0"
FSDP_MIN_VERSION
=
"1.12.0"
XLA_FSDPV2_MIN_VERSION
=
"2.2.0"
_accelerate_available
,
_accelerate_version
=
_is_package_available
(
"accelerate"
,
return_version
=
True
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment