Unverified commit 873d9bb3, authored by Yitong Huang and committed by GitHub

Make torch xla available on GPU (#29334)



* add USE_TORCH_XLA env

* rename torch_tpu to torch_xla

* better is_torch_xla_available; fix some fsdp and performance issues

* fix format

* fix bug when pjrt_device is cpu

* fix bug

* fix the deprecation handling

---------
Co-authored-by: anw90 <ang868@gmail.com>
Co-authored-by: wangang.wa <wangang.wa@alibaba-inc.com>
parent 9a3f4d4d
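
For orientation before the diff: a minimal sketch of how the new environment switch and availability check added by this commit are meant to be used. `USE_TORCH_XLA`, `is_torch_xla_available()` and its `check_is_tpu`/`check_is_gpu` flags come from the diff below; the surrounding script is a hypothetical example, not part of the commit.

# Sketch only: opt out of torch_xla even though it is installed.
# The flag is read when transformers is imported, so set it first.
import os

os.environ["USE_TORCH_XLA"] = "0"

from transformers.utils import is_torch_xla_available

print(is_torch_xla_available())  # False here: disabled via USE_TORCH_XLA=0
# With USE_TORCH_XLA unset (default "1") and torch_xla installed:
#   is_torch_xla_available()                   -> True on TPU, XLA:GPU, or XLA:CPU
#   is_torch_xla_available(check_is_tpu=True)  -> True only when the XLA runtime reports a TPU
#   is_torch_xla_available(check_is_gpu=True)  -> True only when it reports GPU/CUDA (e.g. PJRT_DEVICE=CUDA)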
@@ -39,13 +39,13 @@ from torch.utils.data.distributed import DistributedSampler
 from .integrations.deepspeed import is_deepspeed_zero3_enabled
 from .tokenization_utils_base import BatchEncoding
-from .utils import is_sagemaker_mp_enabled, is_torch_tpu_available, is_training_run_on_sagemaker, logging
+from .utils import is_sagemaker_mp_enabled, is_torch_xla_available, is_training_run_on_sagemaker, logging


 if is_training_run_on_sagemaker():
     logging.add_handler(StreamHandler(sys.stdout))

-if is_torch_tpu_available(check_device=False):
+if is_torch_xla_available():
     import torch_xla.core.xla_model as xm

 # this is used to suppress an undesired warning emitted by pytorch versions 1.4.2-1.7.0
@@ -179,7 +179,7 @@ def nested_detach(tensors):

 def nested_xla_mesh_reduce(tensors, name):
-    if is_torch_tpu_available():
+    if is_torch_xla_available():
         import torch_xla.core.xla_model as xm

         if isinstance(tensors, (list, tuple)):
...
@@ -37,7 +37,7 @@ from .utils import (
     is_torch_cuda_available,
     is_torch_mps_available,
     is_torch_npu_available,
-    is_torch_tpu_available,
+    is_torch_xla_available,
     is_torch_xpu_available,
     requires_backends,
 )
@@ -340,7 +340,7 @@ def is_main_process(local_rank):
     Whether or not the current process is the local process, based on `xm.get_ordinal()` (for TPUs) first, then on
     `local_rank`.
     """
-    if is_torch_tpu_available(check_device=True):
+    if is_torch_xla_available():
         import torch_xla.core.xla_model as xm

         return xm.get_ordinal() == 0
@@ -351,7 +351,7 @@ def total_processes_number(local_rank):
     """
     Return the number of processes launched in parallel. Works with `torch.distributed` and TPUs.
     """
-    if is_torch_tpu_available(check_device=True):
+    if is_torch_xla_available():
         import torch_xla.core.xla_model as xm

         return xm.xrt_world_size()
...
@@ -49,7 +49,7 @@ from .utils import (
     is_torch_neuroncore_available,
     is_torch_npu_available,
     is_torch_tf32_available,
-    is_torch_tpu_available,
+    is_torch_xla_available,
     is_torch_xpu_available,
     logging,
     requires_backends,
@@ -74,7 +74,7 @@ if is_accelerate_available():
     from .trainer_pt_utils import AcceleratorConfig

-if is_torch_tpu_available(check_device=False):
+if is_torch_xla_available():
     import torch_xla.core.xla_model as xm

 if is_torch_neuroncore_available(check_device=False):
@@ -130,7 +130,9 @@ def get_xla_device_type(device: "torch.device") -> Optional[str]:
     """
     Returns the xla device type (CPU|GPU|TPU) or None if the device is a non-xla device.
     """
-    if is_torch_tpu_available():
+    if is_torch_xla_available():
+        if device.type == "cpu":
+            return "CPU"
         return xm.xla_real_devices([device])[0].split(":")[0]
     return None
@@ -1475,7 +1477,7 @@ class TrainingArguments:
             self.half_precision_backend = self.fp16_backend

         if self.bf16 or self.bf16_full_eval:
-            if self.use_cpu and not is_torch_bf16_cpu_available() and not is_torch_tpu_available():
+            if self.use_cpu and not is_torch_bf16_cpu_available() and not is_torch_xla_available():
                 # cpu
                 raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10")
             elif not self.use_cpu:
@@ -1530,7 +1532,7 @@ class TrainingArguments:
             and (self.device.type != "cuda")
             and (self.device.type != "npu")
             and (self.device.type != "xpu")
-            and (get_xla_device_type(self.device) != "GPU")
+            and (get_xla_device_type(self.device) not in ["GPU", "CUDA"])
             and (self.fp16 or self.fp16_full_eval)
         ):
             raise ValueError(
@@ -1544,7 +1546,7 @@ class TrainingArguments:
             and (self.device.type != "cuda")
             and (self.device.type != "npu")
             and (self.device.type != "xpu")
-            and (get_xla_device_type(self.device) != "GPU")
+            and (get_xla_device_type(self.device) not in ["GPU", "CUDA"])
             and (get_xla_device_type(self.device) != "TPU")
             and (self.device.type != "cpu")
             and (self.bf16 or self.bf16_full_eval)
@@ -1694,7 +1696,8 @@ class TrainingArguments:
         if self.fsdp_config["xla"]:
             if len(self.fsdp) > 0:
                 # store XLA fsdp configuration parameters into a dictionary
-                self.xla_fsdp_config = self.fsdp_config.get("xla_fsdp_settings", {})
+                # Copy the config to avoid modifying the original config (which may be used for JSON serialization)
+                self.xla_fsdp_config = self.fsdp_config.get("xla_fsdp_settings", {}).copy()
                 # apply appropriate string to torch.dtype conversions for parameters
                 if "compute_dtype" in self.xla_fsdp_config:
                     self.xla_fsdp_config["compute_dtype"] = getattr(torch, self.xla_fsdp_config["compute_dtype"])
@@ -1948,7 +1951,7 @@ class TrainingArguments:
                     "torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. "
                     "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
                 )
-        if is_torch_tpu_available():
+        if is_torch_xla_available():
             device = self.distributed_state.device
             self._n_gpu = 0
         elif is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled():
@@ -2029,7 +2032,7 @@ class TrainingArguments:
         - `ParallelMode.TPU`: several TPU cores.
         """
         requires_backends(self, ["torch"])
-        if is_torch_tpu_available():
+        if is_torch_xla_available():
            return ParallelMode.TPU
         elif is_sagemaker_mp_enabled():
             return ParallelMode.SAGEMAKER_MODEL_PARALLEL
@@ -2180,7 +2183,7 @@ class TrainingArguments:
                     # tell all replicas to wait
                     logger.debug(f"{self.process_index}: waiting for the {main_process_desc} to perform {desc}")

-                    if is_torch_tpu_available():
+                    if is_torch_xla_available():
                         xm.rendezvous(desc)
                     else:
                         dist.barrier()
@@ -2189,7 +2192,7 @@ class TrainingArguments:
                 if is_main_process:
                     # the wait is over
                     logger.debug(f"{self.process_index}: {main_process_desc} completed {desc}, releasing all replicas")
-                    if is_torch_tpu_available():
+                    if is_torch_xla_available():
                         xm.rendezvous(desc)
                     else:
                         dist.barrier()
...
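
A small illustration of why the `.copy()` added in the FSDP hunk above matters; the dict values here are made up, only the `fsdp_config` / `xla_fsdp_config` handling mirrors the diff.

# Sketch only: without .copy(), converting "compute_dtype" in place would also
# mutate fsdp_config["xla_fsdp_settings"], which may later be JSON-serialized.
import torch

fsdp_config = {"xla": True, "xla_fsdp_settings": {"compute_dtype": "bfloat16"}}

xla_fsdp_config = fsdp_config.get("xla_fsdp_settings", {}).copy()
xla_fsdp_config["compute_dtype"] = getattr(torch, xla_fsdp_config["compute_dtype"])

assert xla_fsdp_config["compute_dtype"] is torch.bfloat16
assert fsdp_config["xla_fsdp_settings"]["compute_dtype"] == "bfloat16"  # original stays JSON-safe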
@@ -189,6 +189,7 @@ from .import_utils import (
     is_torch_tensorrt_fx_available,
     is_torch_tf32_available,
     is_torch_tpu_available,
+    is_torch_xla_available,
     is_torch_xpu_available,
     is_torchaudio_available,
     is_torchdistx_available,
...
@@ -62,6 +62,9 @@ USE_TF = os.environ.get("USE_TF", "AUTO").upper()
 USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
 USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()

+# Try to run a native pytorch job in an environment with TorchXLA installed by setting this value to 0.
+USE_TORCH_XLA = os.environ.get("USE_TORCH_XLA", "1").upper()
+
 FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper()

 # `transformers` requires `torch>=1.11` but this variable is exposed publicly, and we can't simply remove it.
@@ -249,6 +252,13 @@ if _torch_available:
     )

+_torch_xla_available = False
+if USE_TORCH_XLA in ENV_VARS_TRUE_VALUES:
+    _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla", return_version=True)
+    if _torch_xla_available:
+        logger.info(f"Torch XLA version {_torch_xla_version} available.")
+
+
 def is_kenlm_available():
     return _kenlm_available
@@ -484,6 +494,12 @@ def is_g2p_en_available():
 @lru_cache()
 def is_torch_tpu_available(check_device=True):
     "Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
+    warnings.warn(
+        "`is_torch_tpu_available` is deprecated and will be removed in 4.41.0. "
+        "Please use the `is_torch_xla_available` instead.",
+        FutureWarning,
+    )
     if not _torch_available:
         return False
     if importlib.util.find_spec("torch_xla") is not None:
@@ -500,10 +516,31 @@ def is_torch_tpu_available(check_device=True):
     return False


+@lru_cache
+def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False):
+    """
+    Check if `torch_xla` is available. To train a native pytorch job in an environment with torch xla installed, set
+    the USE_TORCH_XLA to false.
+    """
+    assert not (check_is_tpu and check_is_gpu), "The check_is_tpu and check_is_gpu cannot both be true."
+
+    if not _torch_xla_available:
+        return False
+
+    import torch_xla
+
+    if check_is_gpu:
+        return torch_xla.runtime.device_type() in ["GPU", "CUDA"]
+    elif check_is_tpu:
+        return torch_xla.runtime.device_type() == "TPU"
+
+    return True
+
+
 @lru_cache()
 def is_torch_neuroncore_available(check_device=True):
     if importlib.util.find_spec("torch_neuronx") is not None:
-        return is_torch_tpu_available(check_device)
+        return is_torch_xla_available()
     return False
...
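
For downstream code, the deprecation above boils down to the following migration (a sketch; which flag to pass depends on the call site, and the call sites in this diff mostly use the flag-less form).

# Sketch only: is_torch_tpu_available() now emits a FutureWarning and is slated
# for removal in 4.41.0, per the message added above.
from transformers.utils import is_torch_xla_available

if is_torch_xla_available():  # replaces is_torch_tpu_available(check_device=...)
    import torch_xla.core.xla_model as xm  # only reachable when torch_xla is installed and enabled

# Backend-specific checks, if a call site needs them (the two flags are mutually exclusive):
uses_tpu = is_torch_xla_available(check_is_tpu=True)
uses_xla_gpu = is_torch_xla_available(check_is_gpu=True)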