Unverified commit e50be9a0, authored by Zach Mueller, committed by GitHub
Browse files

Guard XLA version imports (#30167)

parent fbdb978e
...@@ -136,6 +136,7 @@ from .utils import ( ...@@ -136,6 +136,7 @@ from .utils import (
SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME,
WEIGHTS_NAME, WEIGHTS_NAME,
XLA_FSDPV2_MIN_VERSION,
PushInProgress, PushInProgress,
PushToHubMixin, PushToHubMixin,
can_return_loss, can_return_loss,
...@@ -179,8 +180,14 @@ if is_datasets_available(): ...@@ -179,8 +180,14 @@ if is_datasets_available():
if is_torch_xla_available(): if is_torch_xla_available():
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met import torch_xla.debug.metrics as met
import torch_xla.distributed.spmd as xs from torch_xla import __version__ as XLA_VERSION
import torch_xla.runtime as xr
IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION)
if IS_XLA_FSDPV2_POST_2_2:
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
else:
IS_XLA_FSDPV2_POST_2_2 = False
if is_sagemaker_mp_enabled(): if is_sagemaker_mp_enabled():
...@@ -664,6 +671,8 @@ class Trainer: ...@@ -664,6 +671,8 @@ class Trainer:
self.is_fsdp_xla_v2_enabled = args.fsdp_config.get("xla_fsdp_v2", False) self.is_fsdp_xla_v2_enabled = args.fsdp_config.get("xla_fsdp_v2", False)
if self.is_fsdp_xla_v2_enabled: if self.is_fsdp_xla_v2_enabled:
if not IS_XLA_FSDPV2_POST_2_2:
raise ValueError("FSDPv2 requires `torch_xla` 2.2 or higher.")
# Prepare the SPMD mesh that is going to be used by the data loader and the FSDPv2 wrapper. # Prepare the SPMD mesh that is going to be used by the data loader and the FSDPv2 wrapper.
# Tensor axis is just a placeholder where it will not be used in FSDPv2. # Tensor axis is just a placeholder where it will not be used in FSDPv2.
num_devices = xr.global_runtime_device_count() num_devices = xr.global_runtime_device_count()
......
...@@ -98,6 +98,7 @@ from .import_utils import ( ...@@ -98,6 +98,7 @@ from .import_utils import (
USE_JAX, USE_JAX,
USE_TF, USE_TF,
USE_TORCH, USE_TORCH,
XLA_FSDPV2_MIN_VERSION,
DummyObject, DummyObject,
OptionalDependencyNotAvailable, OptionalDependencyNotAvailable,
_LazyModule, _LazyModule,
......
...@@ -89,6 +89,7 @@ TORCH_FX_REQUIRED_VERSION = version.parse("1.10") ...@@ -89,6 +89,7 @@ TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
ACCELERATE_MIN_VERSION = "0.21.0" ACCELERATE_MIN_VERSION = "0.21.0"
FSDP_MIN_VERSION = "1.12.0" FSDP_MIN_VERSION = "1.12.0"
XLA_FSDPV2_MIN_VERSION = "2.2.0"
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True) _accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment