Unverified Commit 873d9bb3 authored by Yitong Huang, committed by GitHub

Make torch xla available on GPU (#29334)



* add USE_TORCH_XLA env

* rename torch_tpu to torch_xla

* better is_torch_xla_available; fix some fsdp and performance issues

* fix format

* fix bug when pjrt_device is cpu

* fix bug

* fix the deprecation handling

---------
Co-authored-by: anw90 <ang868@gmail.com>
Co-authored-by: wangang.wa <wangang.wa@alibaba-inc.com>
parent 9a3f4d4d
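
The diff below replaces the TPU-specific availability check with a generic XLA check throughout the Trainer code. As a hedged illustration (not part of the commit), a downstream call site migrates roughly like this, using only the helpers named in the diff:

```python
from transformers.utils import is_torch_tpu_available, is_torch_xla_available

# Before: TPU-only check; now deprecated and emits a FutureWarning.
if is_torch_tpu_available(check_device=False):
    import torch_xla.core.xla_model as xm

# After: generic XLA check that also covers GPU-backed torch_xla runtimes.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm
```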
@@ -39,13 +39,13 @@ from torch.utils.data.distributed import DistributedSampler
 from .integrations.deepspeed import is_deepspeed_zero3_enabled
 from .tokenization_utils_base import BatchEncoding
-from .utils import is_sagemaker_mp_enabled, is_torch_tpu_available, is_training_run_on_sagemaker, logging
+from .utils import is_sagemaker_mp_enabled, is_torch_xla_available, is_training_run_on_sagemaker, logging

 if is_training_run_on_sagemaker():
     logging.add_handler(StreamHandler(sys.stdout))

-if is_torch_tpu_available(check_device=False):
+if is_torch_xla_available():
     import torch_xla.core.xla_model as xm

 # this is used to suppress an undesired warning emitted by pytorch versions 1.4.2-1.7.0
@@ -179,7 +179,7 @@ def nested_detach(tensors):

 def nested_xla_mesh_reduce(tensors, name):
-    if is_torch_tpu_available():
+    if is_torch_xla_available():
         import torch_xla.core.xla_model as xm

         if isinstance(tensors, (list, tuple)):
@@ -37,7 +37,7 @@ from .utils import (
     is_torch_cuda_available,
     is_torch_mps_available,
     is_torch_npu_available,
-    is_torch_tpu_available,
+    is_torch_xla_available,
     is_torch_xpu_available,
     requires_backends,
 )
@@ -340,7 +340,7 @@ def is_main_process(local_rank):
     Whether or not the current process is the local process, based on `xm.get_ordinal()` (for TPUs) first, then on
     `local_rank`.
     """
-    if is_torch_tpu_available(check_device=True):
+    if is_torch_xla_available():
         import torch_xla.core.xla_model as xm

         return xm.get_ordinal() == 0
@@ -351,7 +351,7 @@ def total_processes_number(local_rank):
     """
     Return the number of processes launched in parallel. Works with `torch.distributed` and TPUs.
     """
-    if is_torch_tpu_available(check_device=True):
+    if is_torch_xla_available():
         import torch_xla.core.xla_model as xm

         return xm.xrt_world_size()
@@ -49,7 +49,7 @@ from .utils import (
     is_torch_neuroncore_available,
     is_torch_npu_available,
     is_torch_tf32_available,
-    is_torch_tpu_available,
+    is_torch_xla_available,
     is_torch_xpu_available,
     logging,
     requires_backends,
@@ -74,7 +74,7 @@ if is_accelerate_available():
     from .trainer_pt_utils import AcceleratorConfig

-if is_torch_tpu_available(check_device=False):
+if is_torch_xla_available():
     import torch_xla.core.xla_model as xm

 if is_torch_neuroncore_available(check_device=False):
@@ -130,7 +130,9 @@ def get_xla_device_type(device: "torch.device") -> Optional[str]:
     """
     Returns the xla device type (CPU|GPU|TPU) or None if the device is a non-xla device.
     """
-    if is_torch_tpu_available():
+    if is_torch_xla_available():
+        if device.type == "cpu":
+            return "CPU"
         return xm.xla_real_devices([device])[0].split(":")[0]
     return None
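
As a hedged sketch of the two added lines (assuming torch_xla is installed and an XLA device is initialised): CPU devices are short-circuited, and every other XLA device is classified by the prefix of the "real" device string reported by torch_xla, which is "CUDA" on recent releases and "GPU" on older ones.

```python
import torch_xla.core.xla_model as xm

device = xm.xla_device()                 # e.g. an xla:0 device
real = xm.xla_real_devices([device])[0]  # e.g. "CUDA:0", "GPU:0" or "TPU:0"
print(real.split(":")[0])                # -> "CUDA", "GPU" or "TPU"
```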
@@ -1475,7 +1477,7 @@ class TrainingArguments:
             self.half_precision_backend = self.fp16_backend

         if self.bf16 or self.bf16_full_eval:
-            if self.use_cpu and not is_torch_bf16_cpu_available() and not is_torch_tpu_available():
+            if self.use_cpu and not is_torch_bf16_cpu_available() and not is_torch_xla_available():
                 # cpu
                 raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10")
             elif not self.use_cpu:
@@ -1530,7 +1532,7 @@ class TrainingArguments:
             and (self.device.type != "cuda")
             and (self.device.type != "npu")
             and (self.device.type != "xpu")
-            and (get_xla_device_type(self.device) != "GPU")
+            and (get_xla_device_type(self.device) not in ["GPU", "CUDA"])
             and (self.fp16 or self.fp16_full_eval)
         ):
             raise ValueError(
@@ -1544,7 +1546,7 @@ class TrainingArguments:
             and (self.device.type != "cuda")
             and (self.device.type != "npu")
             and (self.device.type != "xpu")
-            and (get_xla_device_type(self.device) != "GPU")
+            and (get_xla_device_type(self.device) not in ["GPU", "CUDA"])
             and (get_xla_device_type(self.device) != "TPU")
             and (self.device.type != "cpu")
             and (self.bf16 or self.bf16_full_eval)
@@ -1694,7 +1696,8 @@ class TrainingArguments:
         if self.fsdp_config["xla"]:
             if len(self.fsdp) > 0:
                 # store XLA fsdp configuration parameters into a dictionary
-                self.xla_fsdp_config = self.fsdp_config.get("xla_fsdp_settings", {})
+                # Copy the config to avoid modifying the original config (which may be used for JSON serialization)
+                self.xla_fsdp_config = self.fsdp_config.get("xla_fsdp_settings", {}).copy()
                 # apply appropriate string to torch.dtype conversions for parameters
                 if "compute_dtype" in self.xla_fsdp_config:
                     self.xla_fsdp_config["compute_dtype"] = getattr(torch, self.xla_fsdp_config["compute_dtype"])
@@ -1948,7 +1951,7 @@ class TrainingArguments:
                 "torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. "
                 "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
             )
-        if is_torch_tpu_available():
+        if is_torch_xla_available():
             device = self.distributed_state.device
             self._n_gpu = 0
         elif is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled():
@@ -2029,7 +2032,7 @@ class TrainingArguments:
         - `ParallelMode.TPU`: several TPU cores.
         """
         requires_backends(self, ["torch"])
-        if is_torch_tpu_available():
+        if is_torch_xla_available():
             return ParallelMode.TPU
         elif is_sagemaker_mp_enabled():
             return ParallelMode.SAGEMAKER_MODEL_PARALLEL
@@ -2180,7 +2183,7 @@ class TrainingArguments:
                     # tell all replicas to wait
                     logger.debug(f"{self.process_index}: waiting for the {main_process_desc} to perform {desc}")

-                    if is_torch_tpu_available():
+                    if is_torch_xla_available():
                         xm.rendezvous(desc)
                     else:
                         dist.barrier()
@@ -2189,7 +2192,7 @@ class TrainingArguments:
                 if is_main_process:
                     # the wait is over
                     logger.debug(f"{self.process_index}: {main_process_desc} completed {desc}, releasing all replicas")
-                    if is_torch_tpu_available():
+                    if is_torch_xla_available():
                         xm.rendezvous(desc)
                     else:
                         dist.barrier()
@@ -189,6 +189,7 @@ from .import_utils import (
     is_torch_tensorrt_fx_available,
     is_torch_tf32_available,
     is_torch_tpu_available,
+    is_torch_xla_available,
     is_torch_xpu_available,
     is_torchaudio_available,
     is_torchdistx_available,
@@ -62,6 +62,9 @@ USE_TF = os.environ.get("USE_TF", "AUTO").upper()
 USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
 USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()

+# Try to run a native pytorch job in an environment with TorchXLA installed by setting this value to 0.
+USE_TORCH_XLA = os.environ.get("USE_TORCH_XLA", "1").upper()
+
 FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper()

 # `transformers` requires `torch>=1.11` but this variable is exposed publicly, and we can't simply remove it.
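
A hedged usage sketch of the new flag: because `import_utils` reads `USE_TORCH_XLA` at import time, it has to be set before `transformers` is imported (the surrounding script is hypothetical).

```python
import os

# "1" (the default) and other truthy values keep torch_xla enabled; "0" forces a
# plain PyTorch run even when torch_xla is installed in the environment.
os.environ["USE_TORCH_XLA"] = "0"

from transformers.utils import is_torch_xla_available

print(is_torch_xla_available())  # False in this process, even with torch_xla installed
```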
@@ -249,6 +252,13 @@ if _torch_available:
     )

+_torch_xla_available = False
+if USE_TORCH_XLA in ENV_VARS_TRUE_VALUES:
+    _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla", return_version=True)
+    if _torch_xla_available:
+        logger.info(f"Torch XLA version {_torch_xla_version} available.")
+
 def is_kenlm_available():
     return _kenlm_available
@@ -484,6 +494,12 @@ def is_g2p_en_available():
 @lru_cache()
 def is_torch_tpu_available(check_device=True):
     "Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
+    warnings.warn(
+        "`is_torch_tpu_available` is deprecated and will be removed in 4.41.0. "
+        "Please use the `is_torch_xla_available` instead.",
+        FutureWarning,
+    )
     if not _torch_available:
         return False
     if importlib.util.find_spec("torch_xla") is not None:
@@ -500,10 +516,31 @@ def is_torch_tpu_available(check_device=True):
     return False

+@lru_cache
+def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False):
+    """
+    Check if `torch_xla` is available. To train a native pytorch job in an environment with torch xla installed, set
+    the USE_TORCH_XLA to false.
+    """
+    assert not (check_is_tpu and check_is_gpu), "The check_is_tpu and check_is_gpu cannot both be true."
+
+    if not _torch_xla_available:
+        return False
+
+    import torch_xla
+
+    if check_is_gpu:
+        return torch_xla.runtime.device_type() in ["GPU", "CUDA"]
+    elif check_is_tpu:
+        return torch_xla.runtime.device_type() == "TPU"
+
+    return True
+
 @lru_cache()
 def is_torch_neuroncore_available(check_device=True):
     if importlib.util.find_spec("torch_neuronx") is not None:
-        return is_torch_tpu_available(check_device)
+        return is_torch_xla_available()
     return False
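
A hedged usage sketch of the new helper (assuming a transformers build that includes this change): `check_is_tpu` and `check_is_gpu` narrow the check to the PJRT device type and cannot both be true.

```python
from transformers.utils import is_torch_xla_available

if is_torch_xla_available():                   # torch_xla importable and USE_TORCH_XLA enabled
    print("XLA runtime available")

if is_torch_xla_available(check_is_gpu=True):  # PJRT device type is "GPU" or "CUDA"
    print("XLA on GPU")

if is_torch_xla_available(check_is_tpu=True):  # PJRT device type is "TPU"
    print("XLA on TPU")
```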