Unverified Commit e50be9a0 authored by Zach Mueller's avatar Zach Mueller Committed by GitHub
Browse files

Guard XLA version imports (#30167)

parent fbdb978e
......@@ -136,6 +136,7 @@ from .utils import (
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
XLA_FSDPV2_MIN_VERSION,
PushInProgress,
PushToHubMixin,
can_return_loss,
......@@ -179,8 +180,14 @@ if is_datasets_available():
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
from torch_xla import __version__ as XLA_VERSION
IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION)
if IS_XLA_FSDPV2_POST_2_2:
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
else:
IS_XLA_FSDPV2_POST_2_2 = False
if is_sagemaker_mp_enabled():
......@@ -664,6 +671,8 @@ class Trainer:
self.is_fsdp_xla_v2_enabled = args.fsdp_config.get("xla_fsdp_v2", False)
if self.is_fsdp_xla_v2_enabled:
if not IS_XLA_FSDPV2_POST_2_2:
raise ValueError("FSDPv2 requires `torch_xla` 2.2 or higher.")
# Prepare the SPMD mesh that is going to be used by the data loader and the FSDPv2 wrapper.
# Tensor axis is just a placeholder where it will not be used in FSDPv2.
num_devices = xr.global_runtime_device_count()
......
......@@ -98,6 +98,7 @@ from .import_utils import (
USE_JAX,
USE_TF,
USE_TORCH,
XLA_FSDPV2_MIN_VERSION,
DummyObject,
OptionalDependencyNotAvailable,
_LazyModule,
......
......@@ -89,6 +89,7 @@ TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
ACCELERATE_MIN_VERSION = "0.21.0"
FSDP_MIN_VERSION = "1.12.0"
XLA_FSDPV2_MIN_VERSION = "2.2.0"
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment