Unverified Commit 873d9bb3 authored by Yitong Huang's avatar Yitong Huang Committed by GitHub
Browse files

Make torch xla available on GPU (#29334)



* add USE_TORCH_XLA env

* rename torch_tpu to torch_xla

* better is_torch_xla_available; fix some fsdp and performance issues

* fix format

* fix bug when pjrt_device is cpu

* fix bug

* fix the deprecation handling

---------
Co-authored-by: default avataranw90 <ang868@gmail.com>
Co-authored-by: default avatarwangang.wa <wangang.wa@alibaba-inc.com>
parent 9a3f4d4d
......@@ -452,7 +452,7 @@ Dekorateure werden verwendet, um die Anforderungen von Tests in Bezug auf CPU/GP
- `require_torch_multi_gpu` - wie `require_torch` und zusätzlich mindestens 2 GPUs erforderlich
- `require_torch_non_multi_gpu` - wie `require_torch` plus benötigt 0 oder 1 GPUs
- `require_torch_up_to_2_gpus` - wie `require_torch` plus erfordert 0 oder 1 oder 2 GPUs
- `require_torch_tpu` - wie `require_torch` plus erfordert mindestens 1 TPU
- `require_torch_xla` - wie `require_torch` plus erfordert mindestens 1 TPU
Lassen Sie uns die GPU-Anforderungen in der folgenden Tabelle darstellen:
......
......@@ -451,7 +451,7 @@ decorators are used to set the requirements of tests CPU/GPU/TPU-wise:
- `require_torch_multi_gpu` - as `require_torch` plus requires at least 2 GPUs
- `require_torch_non_multi_gpu` - as `require_torch` plus requires 0 or 1 GPUs
- `require_torch_up_to_2_gpus` - as `require_torch` plus requires 0 or 1 or 2 GPUs
- `require_torch_tpu` - as `require_torch` plus requires at least 1 TPU
- `require_torch_xla` - as `require_torch` plus requires at least 1 TPU
Let's depict the GPU requirements in the following table:
......
......@@ -424,7 +424,7 @@ CUDA_VISIBLE_DEVICES="1" pytest tests/utils/test_logging.py
- `require_torch_multi_gpu` - `require_torch` に加えて、少なくとも2つのGPUが必要です。
- `require_torch_non_multi_gpu` - `require_torch` に加えて、0または1つのGPUが必要です。
- `require_torch_up_to_2_gpus` - `require_torch` に加えて、0、1、または2つのGPUが必要です。
- `require_torch_tpu` - `require_torch` に加えて、少なくとも1つのTPUが必要です。
- `require_torch_xla` - `require_torch` に加えて、少なくとも1つのTPUが必要です。
以下の表にGPUの要件を示します:
......
......@@ -452,7 +452,7 @@ CUDA_VISIBLE_DEVICES="1" pytest tests/utils/test_logging.py
- `require_torch_multi_gpu` - `require_torch`에 추가로 적어도 2개의 GPU가 필요합니다.
- `require_torch_non_multi_gpu` - `require_torch`에 추가로 0개 또는 1개의 GPU가 필요합니다.
- `require_torch_up_to_2_gpus` - `require_torch`에 추가로 0개, 1개 또는 2개의 GPU가 필요합니다.
- `require_torch_tpu` - `require_torch`에 추가로 적어도 1개의 TPU가 필요합니다.
- `require_torch_xla` - `require_torch`에 추가로 적어도 1개의 TPU가 필요합니다.
GPU 요구 사항을 표로 정리하면 아래와 같습니디ㅏ:
......
......@@ -32,7 +32,7 @@ from transformers.optimization import (
)
from transformers.trainer_pt_utils import get_tpu_sampler
from transformers.training_args import ParallelMode
from transformers.utils import is_torch_tpu_available
from transformers.utils import is_torch_xla_available
logger = logging.get_logger(__name__)
......@@ -135,7 +135,7 @@ class Seq2SeqTrainer(Trainer):
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
return None
elif is_torch_tpu_available():
elif is_torch_xla_available():
return get_tpu_sampler(self.train_dataset)
else:
if self.args.sortish_sampler:
......
......@@ -46,7 +46,7 @@ from transformers import (
Trainer,
TrainingArguments,
default_data_collator,
is_torch_tpu_available,
is_torch_xla_available,
set_seed,
)
from transformers.testing_utils import CaptureLogger
......@@ -602,9 +602,9 @@ def main():
tokenizer=tokenizer,
# Data collator will default to DataCollatorWithPadding, so we change it.
data_collator=default_data_collator,
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None,
preprocess_logits_for_metrics=preprocess_logits_for_metrics
if training_args.do_eval and not is_torch_tpu_available()
if training_args.do_eval and not is_torch_xla_available()
else None,
)
......
......@@ -45,7 +45,7 @@ from transformers import (
HfArgumentParser,
Trainer,
TrainingArguments,
is_torch_tpu_available,
is_torch_xla_available,
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
......@@ -620,9 +620,9 @@ def main():
eval_dataset=eval_dataset if training_args.do_eval else None,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None,
preprocess_logits_for_metrics=preprocess_logits_for_metrics
if training_args.do_eval and not is_torch_tpu_available()
if training_args.do_eval and not is_torch_xla_available()
else None,
)
......
......@@ -21,7 +21,7 @@ import sys
from time import time
from unittest.mock import patch
from transformers.testing_utils import TestCasePlus, require_torch_tpu
from transformers.testing_utils import TestCasePlus, require_torch_xla
logging.basicConfig(level=logging.DEBUG)
......@@ -44,7 +44,7 @@ stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
@require_torch_tpu
@require_torch_xla
class TorchXLAExamplesTests(TestCasePlus):
def test_run_glue(self):
import xla_spawn
......
......@@ -18,11 +18,11 @@ A subclass of `Trainer` specific to Question-Answering tasks
import math
import time
from transformers import Trainer, is_torch_tpu_available
from transformers import Trainer, is_torch_xla_available
from transformers.trainer_utils import PredictionOutput, speed_metrics
if is_torch_tpu_available(check_device=False):
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
......
......@@ -21,11 +21,11 @@ from typing import Dict, List, Optional
from torch.utils.data import Dataset
from transformers import Seq2SeqTrainer, is_torch_tpu_available
from transformers import Seq2SeqTrainer, is_torch_xla_available
from transformers.trainer_utils import PredictionOutput, speed_metrics
if is_torch_tpu_available(check_device=False):
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
......
......@@ -24,13 +24,13 @@ import quant_trainer
import torch
from torch.utils.data import DataLoader
from transformers import Trainer, is_torch_tpu_available
from transformers import Trainer, is_torch_xla_available
from transformers.trainer_utils import PredictionOutput
logger = logging.getLogger(__name__)
if is_torch_tpu_available(check_device=False):
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
......
......@@ -1093,6 +1093,7 @@ _import_structure = {
"is_torch_npu_available",
"is_torch_tpu_available",
"is_torchvision_available",
"is_torch_xla_available",
"is_torch_xpu_available",
"is_vision_available",
"logging",
......@@ -5897,6 +5898,7 @@ if TYPE_CHECKING:
is_torch_neuroncore_available,
is_torch_npu_available,
is_torch_tpu_available,
is_torch_xla_available,
is_torch_xpu_available,
is_torchvision_available,
is_vision_available,
......
......@@ -20,7 +20,7 @@ from typing import Tuple
from ..utils import (
cached_property,
is_torch_available,
is_torch_tpu_available,
is_torch_xla_available,
is_torch_xpu_available,
logging,
requires_backends,
......@@ -31,7 +31,7 @@ from .benchmark_args_utils import BenchmarkArguments
if is_torch_available():
import torch
if is_torch_tpu_available(check_device=False):
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
......@@ -88,7 +88,7 @@ class PyTorchBenchmarkArguments(BenchmarkArguments):
if not self.cuda:
device = torch.device("cpu")
n_gpu = 0
elif is_torch_tpu_available():
elif is_torch_xla_available():
device = xm.xla_device()
n_gpu = 0
elif is_torch_xpu_available():
......@@ -101,7 +101,7 @@ class PyTorchBenchmarkArguments(BenchmarkArguments):
@property
def is_tpu(self):
return is_torch_tpu_available() and self.tpu
return is_torch_xla_available() and self.tpu
@property
def device_idx(self) -> int:
......
......@@ -121,7 +121,7 @@ from .utils import (
is_torch_fx_proxy,
is_torch_mps_available,
is_torch_tf32_available,
is_torch_tpu_available,
is_torch_xla_available,
is_torchaudio_available,
is_training_run_on_sagemaker,
is_vision_available,
......
......@@ -72,7 +72,7 @@ if TYPE_CHECKING and _has_neptune:
from ..trainer_callback import ProgressCallback, TrainerCallback # noqa: E402
from ..trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402
from ..training_args import ParallelMode # noqa: E402
from ..utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available # noqa: E402
from ..utils import ENV_VARS_TRUE_VALUES, is_torch_xla_available # noqa: E402
# Integration functions:
......@@ -752,7 +752,7 @@ class WandbCallback(TrainerCallback):
# keep track of model topology and gradients, unsupported on TPU
_watch_model = os.getenv("WANDB_WATCH", "false")
if not is_torch_tpu_available() and _watch_model in ("all", "parameters", "gradients"):
if not is_torch_xla_available() and _watch_model in ("all", "parameters", "gradients"):
self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps))
self._wandb.run._label(code="transformers_trainer")
......
......@@ -14,11 +14,11 @@
from torch.utils.data import DataLoader
from ..utils import is_torch_tpu_available
from ..utils import is_torch_xla_available
def tpu_spmd_dataloader(dataloader: DataLoader):
if is_torch_tpu_available():
if is_torch_xla_available():
import torch_xla.distributed.parallel_loader as pl
assert isinstance(
......
......@@ -84,7 +84,7 @@ from .utils import (
is_remote_url,
is_safetensors_available,
is_torch_sdpa_available,
is_torch_tpu_available,
is_torch_xla_available,
logging,
replace_return_docstrings,
strtobool,
......@@ -246,10 +246,10 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
# Adding fix for https://github.com/pytorch/xla/issues/4152
# Fixes issue where the model code passes a value that is out of range for XLA_USE_BF16=1
# and XLA_DOWNCAST_BF16=1 so the conversion would cast it to -inf
# NOTE: `is_torch_tpu_available()` is checked last as it induces a graph break in torch dynamo
if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES and is_torch_tpu_available():
# NOTE: `is_torch_xla_available()` is checked last as it induces a graph break in torch dynamo
if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES and is_torch_xla_available():
return torch.bfloat16
if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES and is_torch_tpu_available():
if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES and is_torch_xla_available():
if t.dtype == torch.float:
return torch.bfloat16
if t.dtype == torch.double:
......
......@@ -19,7 +19,7 @@ from packaging import version
from safetensors.torch import storage_ptr, storage_size
from torch import nn
from .utils import is_torch_tpu_available, logging
from .utils import is_torch_xla_available, logging
ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
......@@ -282,7 +282,7 @@ def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]:
guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
non-overlapping lifetimes may have the same id.
"""
if tensor.device.type == "xla" and is_torch_tpu_available():
if tensor.device.type == "xla" and is_torch_xla_available():
# NOTE: xla tensors dont have storage
# use some other unique id to distinguish.
# this is a XLA tensor, it must be created using torch_xla's
......
......@@ -115,7 +115,7 @@ from .utils import (
is_torch_sdpa_available,
is_torch_tensorrt_fx_available,
is_torch_tf32_available,
is_torch_tpu_available,
is_torch_xla_available,
is_torch_xpu_available,
is_torchaudio_available,
is_torchdynamo_available,
......@@ -733,11 +733,11 @@ def require_torch_up_to_2_accelerators(test_case):
(test_case)
def require_torch_tpu(test_case):
def require_torch_xla(test_case):
"""
Decorator marking a test that requires a TPU (in PyTorch).
Decorator marking a test that requires TorchXLA (in PyTorch).
"""
return unittest.skipUnless(is_torch_tpu_available(check_device=False), "test requires PyTorch TPU")(test_case)
return unittest.skipUnless(is_torch_xla_available(), "test requires TorchXLA")(test_case)
def require_torch_neuroncore(test_case):
......
......@@ -149,7 +149,7 @@ from .utils import (
is_torch_compile_available,
is_torch_neuroncore_available,
is_torch_npu_available,
is_torch_tpu_available,
is_torch_xla_available,
logging,
strtobool,
)
......@@ -170,7 +170,7 @@ if is_apex_available():
if is_datasets_available():
import datasets
if is_torch_tpu_available(check_device=False):
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.spmd as xs
......@@ -508,7 +508,7 @@ class Trainer:
"Passing a `model_init` is incompatible with providing the `optimizers` argument. "
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
)
if is_torch_tpu_available() and self.optimizer is not None:
if is_torch_xla_available() and self.optimizer is not None:
for param in self.model.parameters():
model_device = param.device
break
......@@ -856,7 +856,7 @@ class Trainer:
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
# Deprecated code
if self.args.use_legacy_prediction_loop:
if is_torch_tpu_available():
if is_torch_xla_available():
return SequentialDistributedSampler(
eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
)
......@@ -1975,7 +1975,7 @@ class Trainer:
if (
args.logging_nan_inf_filter
and not is_torch_tpu_available()
and not is_torch_xla_available()
and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
):
# if loss is nan or inf simply add the average of previous logged losses
......@@ -2027,7 +2027,7 @@ class Trainer:
if hasattr(grad_norm, "item"):
grad_norm = grad_norm.item()
else:
grad_norm = _grad_norm.item() if _grad_norm is not None else None
grad_norm = _grad_norm
# Optimizer step
self.optimizer.step()
......@@ -2050,7 +2050,7 @@ class Trainer:
# PyTorch/XLA relies on the data loader to insert the mark_step for
# each step. Since we are breaking the loop early, we need to manually
# insert the mark_step here.
if is_torch_tpu_available():
if is_torch_xla_available():
xm.mark_step()
break
if step < 0:
......@@ -2065,7 +2065,7 @@ class Trainer:
self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
if is_torch_tpu_available():
if is_torch_xla_available():
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
xm.master_print(met.metrics_report())
else:
......@@ -2083,7 +2083,7 @@ class Trainer:
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
# Wait for everyone to get here so we are sure the model has been saved by process 0.
if is_torch_tpu_available():
if is_torch_xla_available():
xm.rendezvous("load_best_model_at_end")
elif args.parallel_mode == ParallelMode.DISTRIBUTED:
dist.barrier()
......@@ -2402,7 +2402,7 @@ class Trainer:
def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
if is_torch_tpu_available():
if is_torch_xla_available():
xm.mark_step()
logs: Dict[str, float] = {}
......@@ -2415,7 +2415,7 @@ class Trainer:
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
if grad_norm is not None:
logs["grad_norm"] = grad_norm
logs["grad_norm"] = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm
logs["learning_rate"] = self._get_learning_rate()
self._total_loss_scalar += tr_loss_scalar
......@@ -2478,7 +2478,7 @@ class Trainer:
f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}"
"\nThis won't yield the same results as if the training had not been interrupted."
)
if is_torch_tpu_available():
if is_torch_xla_available():
xm.set_rng_state(checkpoint_rng_state["xla"])
if is_torch_npu_available():
if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
......@@ -2556,7 +2556,7 @@ class Trainer:
else:
rng_states["cuda"] = torch.cuda.random.get_rng_state()
if is_torch_tpu_available():
if is_torch_xla_available():
rng_states["xla"] = xm.get_rng_state()
if is_torch_npu_available():
......@@ -2575,7 +2575,7 @@ class Trainer:
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))
def _save_optimizer_and_scheduler(self, output_dir):
if is_torch_tpu_available():
if is_torch_xla_available():
xm.rendezvous("saving_optimizer_states")
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
with warnings.catch_warnings(record=True) as caught_warnings:
......@@ -2620,7 +2620,7 @@ class Trainer:
if (
self.args.should_save
and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler)
and not is_torch_tpu_available()
and not is_torch_xla_available()
):
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
......@@ -2657,7 +2657,7 @@ class Trainer:
)
if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
# Load in optimizer and scheduler states
if is_torch_tpu_available():
if is_torch_xla_available():
# On TPU we have to take some extra precautions to properly load the states on the right device.
optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
with warnings.catch_warnings(record=True) as caught_warnings:
......@@ -2964,7 +2964,7 @@ class Trainer:
if output_dir is None:
output_dir = self.args.output_dir
if is_torch_tpu_available():
if is_torch_xla_available():
self._save_tpu(output_dir)
elif is_sagemaker_mp_enabled():
# Calling the state_dict needs to be done on the wrapped model and on all processes.
......@@ -3405,7 +3405,7 @@ class Trainer:
main_input_name = getattr(self.model, "main_input_name", "input_ids")
inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None
if is_torch_tpu_available():
if is_torch_xla_available():
xm.mark_step()
# Update containers on host
......@@ -3529,7 +3529,7 @@ class Trainer:
"""
if tensors is None:
return
if is_torch_tpu_available():
if is_torch_xla_available():
if name is None:
name = "nested_gather"
tensors = nested_xla_mesh_reduce(tensors, name)
......@@ -4045,7 +4045,7 @@ class Trainer:
"""
if tensors is None:
return
if is_torch_tpu_available():
if is_torch_xla_available():
tensors = nested_xla_mesh_reduce(tensors, name)
elif is_sagemaker_mp_enabled():
tensors = smp_gather(tensors)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment