Unverified commit 873d9bb3, authored by Yitong Huang and committed by GitHub

Make torch xla available on GPU (#29334)



* add USE_TORCH_XLA env

* rename torch_tpu to torch_xla

* better is_torch_xla_available; fix some fsdp and performance issues

* fix format

* fix bug when pjrt_device is cpu

* fix bug

* fix the deprecation handling

---------
Co-authored-by: anw90 <ang868@gmail.com>
Co-authored-by: wangang.wa <wangang.wa@alibaba-inc.com>
parent 9a3f4d4d
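The commit message mentions a new `USE_TORCH_XLA` environment variable and a reworked `is_torch_xla_available()` check, but the helper's implementation is not among the hunks shown below. The following is only a minimal sketch of how such a check might be gated on that variable; the function name `is_torch_xla_available_sketch`, the `"AUTO"` default, and the exact flag semantics are assumptions, not the code from this commit.

```python
# Minimal sketch, not the transformers implementation: illustrates how an
# availability check could be gated on a USE_TORCH_XLA environment variable.
import importlib.util
import os

ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}  # assumption: mirrors transformers' constant


def is_torch_xla_available_sketch() -> bool:
    # Respect an explicit opt-in/opt-out via USE_TORCH_XLA (assumed semantics).
    use_torch_xla = os.environ.get("USE_TORCH_XLA", "AUTO").upper()
    if use_torch_xla not in ENV_VARS_TRUE_VALUES | {"AUTO"}:
        return False
    # Fall back to a pure import check; no device probing, so the result is the
    # same on GPU, TPU, or CPU hosts as long as torch_xla is installed.
    return importlib.util.find_spec("torch_xla") is not None
```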
@@ -452,7 +452,7 @@ Dekorateure werden verwendet, um die Anforderungen von Tests in Bezug auf CPU/GP
 - `require_torch_multi_gpu` - wie `require_torch` und zusätzlich mindestens 2 GPUs erforderlich
 - `require_torch_non_multi_gpu` - wie `require_torch` plus benötigt 0 oder 1 GPUs
 - `require_torch_up_to_2_gpus` - wie `require_torch` plus erfordert 0 oder 1 oder 2 GPUs
-- `require_torch_tpu` - wie `require_torch` plus erfordert mindestens 1 TPU
+- `require_torch_xla` - wie `require_torch` plus erfordert mindestens 1 TPU
 Lassen Sie uns die GPU-Anforderungen in der folgenden Tabelle darstellen:
...
@@ -451,7 +451,7 @@ decorators are used to set the requirements of tests CPU/GPU/TPU-wise:
 - `require_torch_multi_gpu` - as `require_torch` plus requires at least 2 GPUs
 - `require_torch_non_multi_gpu` - as `require_torch` plus requires 0 or 1 GPUs
 - `require_torch_up_to_2_gpus` - as `require_torch` plus requires 0 or 1 or 2 GPUs
-- `require_torch_tpu` - as `require_torch` plus requires at least 1 TPU
+- `require_torch_xla` - as `require_torch` plus requires at least 1 TPU
 Let's depict the GPU requirements in the following table:
...
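For reference, a test opts into the renamed decorator the same way it used `require_torch_tpu`. The test class and body below are an illustrative placeholder, not taken from the repository.

```python
# Hypothetical test module showing the renamed decorator in use.
import unittest

import torch

from transformers.testing_utils import require_torch_xla


@require_torch_xla
class XlaSmokeTest(unittest.TestCase):
    def test_tensor_on_xla_device(self):
        # Safe to import here: the decorator skips the whole class when torch_xla is missing.
        import torch_xla.core.xla_model as xm

        device = xm.xla_device()
        x = torch.ones(2, 2, device=device)
        self.assertEqual(x.sum().item(), 4.0)
```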
@@ -424,7 +424,7 @@ CUDA_VISIBLE_DEVICES="1" pytest tests/utils/test_logging.py
 - `require_torch_multi_gpu` - `require_torch` に加えて、少なくとも2つのGPUが必要です。
 - `require_torch_non_multi_gpu` - `require_torch` に加えて、0または1つのGPUが必要です。
 - `require_torch_up_to_2_gpus` - `require_torch` に加えて、0、1、または2つのGPUが必要です。
-- `require_torch_tpu` - `require_torch` に加えて、少なくとも1つのTPUが必要です。
+- `require_torch_xla` - `require_torch` に加えて、少なくとも1つのTPUが必要です。
 以下の表にGPUの要件を示します:
...
@@ -452,7 +452,7 @@ CUDA_VISIBLE_DEVICES="1" pytest tests/utils/test_logging.py
 - `require_torch_multi_gpu` - `require_torch`에 추가로 적어도 2개의 GPU가 필요합니다.
 - `require_torch_non_multi_gpu` - `require_torch`에 추가로 0개 또는 1개의 GPU가 필요합니다.
 - `require_torch_up_to_2_gpus` - `require_torch`에 추가로 0개, 1개 또는 2개의 GPU가 필요합니다.
-- `require_torch_tpu` - `require_torch`에 추가로 적어도 1개의 TPU가 필요합니다.
+- `require_torch_xla` - `require_torch`에 추가로 적어도 1개의 TPU가 필요합니다.
 GPU 요구 사항을 표로 정리하면 아래와 같습니디ㅏ:
...
@@ -32,7 +32,7 @@ from transformers.optimization import (
 )
 from transformers.trainer_pt_utils import get_tpu_sampler
 from transformers.training_args import ParallelMode
-from transformers.utils import is_torch_tpu_available
+from transformers.utils import is_torch_xla_available
 logger = logging.get_logger(__name__)
@@ -135,7 +135,7 @@ class Seq2SeqTrainer(Trainer):
     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
         if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
             return None
-        elif is_torch_tpu_available():
+        elif is_torch_xla_available():
             return get_tpu_sampler(self.train_dataset)
         else:
             if self.args.sortish_sampler:
...
@@ -46,7 +46,7 @@ from transformers import (
     Trainer,
     TrainingArguments,
     default_data_collator,
-    is_torch_tpu_available,
+    is_torch_xla_available,
     set_seed,
 )
 from transformers.testing_utils import CaptureLogger
@@ -602,9 +602,9 @@ def main():
         tokenizer=tokenizer,
         # Data collator will default to DataCollatorWithPadding, so we change it.
         data_collator=default_data_collator,
-        compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
+        compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None,
         preprocess_logits_for_metrics=preprocess_logits_for_metrics
-        if training_args.do_eval and not is_torch_tpu_available()
+        if training_args.do_eval and not is_torch_xla_available()
         else None,
     )
...
@@ -45,7 +45,7 @@ from transformers import (
     HfArgumentParser,
     Trainer,
     TrainingArguments,
-    is_torch_tpu_available,
+    is_torch_xla_available,
     set_seed,
 )
 from transformers.trainer_utils import get_last_checkpoint
@@ -620,9 +620,9 @@ def main():
         eval_dataset=eval_dataset if training_args.do_eval else None,
         tokenizer=tokenizer,
         data_collator=data_collator,
-        compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
+        compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None,
         preprocess_logits_for_metrics=preprocess_logits_for_metrics
-        if training_args.do_eval and not is_torch_tpu_available()
+        if training_args.do_eval and not is_torch_xla_available()
         else None,
     )
...
@@ -21,7 +21,7 @@ import sys
 from time import time
 from unittest.mock import patch
-from transformers.testing_utils import TestCasePlus, require_torch_tpu
+from transformers.testing_utils import TestCasePlus, require_torch_xla
 logging.basicConfig(level=logging.DEBUG)
@@ -44,7 +44,7 @@ stream_handler = logging.StreamHandler(sys.stdout)
 logger.addHandler(stream_handler)
-@require_torch_tpu
+@require_torch_xla
 class TorchXLAExamplesTests(TestCasePlus):
     def test_run_glue(self):
         import xla_spawn
...
@@ -18,11 +18,11 @@ A subclass of `Trainer` specific to Question-Answering tasks
 import math
 import time
-from transformers import Trainer, is_torch_tpu_available
+from transformers import Trainer, is_torch_xla_available
 from transformers.trainer_utils import PredictionOutput, speed_metrics
-if is_torch_tpu_available(check_device=False):
+if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
     import torch_xla.debug.metrics as met
...
@@ -21,11 +21,11 @@ from typing import Dict, List, Optional
 from torch.utils.data import Dataset
-from transformers import Seq2SeqTrainer, is_torch_tpu_available
+from transformers import Seq2SeqTrainer, is_torch_xla_available
 from transformers.trainer_utils import PredictionOutput, speed_metrics
-if is_torch_tpu_available(check_device=False):
+if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
     import torch_xla.debug.metrics as met
...
@@ -24,13 +24,13 @@ import quant_trainer
 import torch
 from torch.utils.data import DataLoader
-from transformers import Trainer, is_torch_tpu_available
+from transformers import Trainer, is_torch_xla_available
 from transformers.trainer_utils import PredictionOutput
 logger = logging.getLogger(__name__)
-if is_torch_tpu_available(check_device=False):
+if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
     import torch_xla.debug.metrics as met
...
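The trainer subclasses above all follow the same pattern: import `torch_xla` modules only when `is_torch_xla_available()` is true, then use them for step marking and debug metrics. A condensed, standalone sketch of that pattern is shown below; `run_steps` and its placeholder body are assumptions for illustration only.

```python
# Sketch of the guarded-import pattern used above; the per-step work is a placeholder.
from transformers import is_torch_xla_available

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met


def run_steps(num_steps: int) -> None:
    for _ in range(num_steps):
        ...  # placeholder for whatever work a trainer does per step
        if is_torch_xla_available():
            xm.mark_step()  # cut the lazy graph so XLA compiles and executes this step
    if is_torch_xla_available():
        # master_print only emits on the master ordinal; metrics_report dumps compile/execute stats.
        xm.master_print(met.metrics_report())
```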
@@ -1093,6 +1093,7 @@ _import_structure = {
         "is_torch_npu_available",
         "is_torch_tpu_available",
         "is_torchvision_available",
+        "is_torch_xla_available",
         "is_torch_xpu_available",
         "is_vision_available",
         "logging",
@@ -5897,6 +5898,7 @@ if TYPE_CHECKING:
         is_torch_neuroncore_available,
         is_torch_npu_available,
         is_torch_tpu_available,
+        is_torch_xla_available,
         is_torch_xpu_available,
         is_torchvision_available,
         is_vision_available,
...
@@ -20,7 +20,7 @@ from typing import Tuple
 from ..utils import (
     cached_property,
     is_torch_available,
-    is_torch_tpu_available,
+    is_torch_xla_available,
     is_torch_xpu_available,
     logging,
     requires_backends,
@@ -31,7 +31,7 @@ from .benchmark_args_utils import BenchmarkArguments
 if is_torch_available():
     import torch
-if is_torch_tpu_available(check_device=False):
+if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
@@ -88,7 +88,7 @@ class PyTorchBenchmarkArguments(BenchmarkArguments):
         if not self.cuda:
             device = torch.device("cpu")
             n_gpu = 0
-        elif is_torch_tpu_available():
+        elif is_torch_xla_available():
             device = xm.xla_device()
             n_gpu = 0
         elif is_torch_xpu_available():
@@ -101,7 +101,7 @@ class PyTorchBenchmarkArguments(BenchmarkArguments):
     @property
     def is_tpu(self):
-        return is_torch_tpu_available() and self.tpu
+        return is_torch_xla_available() and self.tpu
     @property
     def device_idx(self) -> int:
...
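The benchmark arguments above pick a device in the order CPU flag → XLA → XPU → CUDA. A reduced, standalone sketch of that selection order is given below; the `pick_device` helper and its `use_cpu` parameter are hypothetical, not part of the class shown in the hunk.

```python
# Standalone sketch of the device-selection order visible in the hunk above.
import torch

from transformers import is_torch_xla_available, is_torch_xpu_available


def pick_device(use_cpu: bool = False) -> torch.device:
    if use_cpu:
        return torch.device("cpu")
    if is_torch_xla_available():
        import torch_xla.core.xla_model as xm

        # With this PR, the XLA branch also covers XLA-on-GPU, not only TPU.
        return xm.xla_device()
    if is_torch_xpu_available():
        return torch.device("xpu")
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
```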
@@ -121,7 +121,7 @@ from .utils import (
     is_torch_fx_proxy,
     is_torch_mps_available,
     is_torch_tf32_available,
-    is_torch_tpu_available,
+    is_torch_xla_available,
     is_torchaudio_available,
     is_training_run_on_sagemaker,
     is_vision_available,
...
@@ -72,7 +72,7 @@ if TYPE_CHECKING and _has_neptune:
 from ..trainer_callback import ProgressCallback, TrainerCallback  # noqa: E402
 from ..trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy  # noqa: E402
 from ..training_args import ParallelMode  # noqa: E402
-from ..utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available  # noqa: E402
+from ..utils import ENV_VARS_TRUE_VALUES, is_torch_xla_available  # noqa: E402
 # Integration functions:
@@ -752,7 +752,7 @@ class WandbCallback(TrainerCallback):
             # keep track of model topology and gradients, unsupported on TPU
             _watch_model = os.getenv("WANDB_WATCH", "false")
-            if not is_torch_tpu_available() and _watch_model in ("all", "parameters", "gradients"):
+            if not is_torch_xla_available() and _watch_model in ("all", "parameters", "gradients"):
                 self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps))
             self._wandb.run._label(code="transformers_trainer")
...
@@ -14,11 +14,11 @@
 from torch.utils.data import DataLoader
-from ..utils import is_torch_tpu_available
+from ..utils import is_torch_xla_available
 def tpu_spmd_dataloader(dataloader: DataLoader):
-    if is_torch_tpu_available():
+    if is_torch_xla_available():
         import torch_xla.distributed.parallel_loader as pl
         assert isinstance(
...
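`tpu_spmd_dataloader` (only partially shown above) wraps a regular `DataLoader` for XLA execution. As background, the usual way to feed an XLA device from a host `DataLoader` is `torch_xla`'s `MpDeviceLoader`; the sketch below shows that wrapping with a synthetic dataset and omits the SPMD-specific input-sharding arguments that the real helper adds.

```python
# Background sketch: wrapping a host DataLoader for an XLA device.
# The dataset here is synthetic, and the SPMD sharding configuration is left out.
import torch
from torch.utils.data import DataLoader, TensorDataset

from transformers import is_torch_xla_available

dataset = TensorDataset(torch.randn(32, 4), torch.randint(0, 2, (32,)))
dataloader = DataLoader(dataset, batch_size=8)

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl

    device = xm.xla_device()
    # MpDeviceLoader moves each batch to the XLA device and triggers execution as you iterate.
    xla_loader = pl.MpDeviceLoader(dataloader, device)
    for features, labels in xla_loader:
        pass  # batches arrive already placed on the XLA device
```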
@@ -84,7 +84,7 @@ from .utils import (
     is_remote_url,
     is_safetensors_available,
     is_torch_sdpa_available,
-    is_torch_tpu_available,
+    is_torch_xla_available,
     logging,
     replace_return_docstrings,
     strtobool,
@@ -246,10 +246,10 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
             # Adding fix for https://github.com/pytorch/xla/issues/4152
             # Fixes issue where the model code passes a value that is out of range for XLA_USE_BF16=1
             # and XLA_DOWNCAST_BF16=1 so the conversion would cast it to -inf
-            # NOTE: `is_torch_tpu_available()` is checked last as it induces a graph break in torch dynamo
-            if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES and is_torch_tpu_available():
+            # NOTE: `is_torch_xla_available()` is checked last as it induces a graph break in torch dynamo
+            if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES and is_torch_xla_available():
                 return torch.bfloat16
-            if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES and is_torch_tpu_available():
+            if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES and is_torch_xla_available():
                 if t.dtype == torch.float:
                     return torch.bfloat16
                 if t.dtype == torch.double:
...
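The `get_parameter_dtype` hunk above only switches the guard to `is_torch_xla_available()`; the surrounding bf16 rules are unchanged. Read in isolation, those rules amount to the standalone sketch below. The function name `xla_effective_dtype` is hypothetical, and the double-to-float32 branch follows torch_xla's documented `XLA_DOWNCAST_BF16` behavior rather than lines visible in the hunk.

```python
# Standalone restatement of the XLA bf16 env-var rules referenced in the hunk above.
import os

import torch

ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}  # assumed to match transformers' constant
XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()


def xla_effective_dtype(t: torch.Tensor) -> torch.dtype:
    """Dtype a floating-point parameter is treated as under the XLA bf16 env vars."""
    if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES:
        return torch.bfloat16
    if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES:
        if t.dtype == torch.float:
            return torch.bfloat16
        if t.dtype == torch.double:
            return torch.float32
    return t.dtype
```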
@@ -19,7 +19,7 @@ from packaging import version
 from safetensors.torch import storage_ptr, storage_size
 from torch import nn
-from .utils import is_torch_tpu_available, logging
+from .utils import is_torch_xla_available, logging
 ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
@@ -282,7 +282,7 @@ def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]:
     guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
     non-overlapping lifetimes may have the same id.
     """
-    if tensor.device.type == "xla" and is_torch_tpu_available():
+    if tensor.device.type == "xla" and is_torch_xla_available():
         # NOTE: xla tensors dont have storage
         # use some other unique id to distinguish.
         # this is a XLA tensor, it must be created using torch_xla's
...
@@ -115,7 +115,7 @@ from .utils import (
     is_torch_sdpa_available,
     is_torch_tensorrt_fx_available,
     is_torch_tf32_available,
-    is_torch_tpu_available,
+    is_torch_xla_available,
     is_torch_xpu_available,
     is_torchaudio_available,
     is_torchdynamo_available,
@@ -733,11 +733,11 @@ def require_torch_up_to_2_accelerators(test_case):
     (test_case)
-def require_torch_tpu(test_case):
+def require_torch_xla(test_case):
     """
-    Decorator marking a test that requires a TPU (in PyTorch).
+    Decorator marking a test that requires TorchXLA (in PyTorch).
     """
-    return unittest.skipUnless(is_torch_tpu_available(check_device=False), "test requires PyTorch TPU")(test_case)
+    return unittest.skipUnless(is_torch_xla_available(), "test requires TorchXLA")(test_case)
 def require_torch_neuroncore(test_case):
...
@@ -149,7 +149,7 @@ from .utils import (
     is_torch_compile_available,
     is_torch_neuroncore_available,
     is_torch_npu_available,
-    is_torch_tpu_available,
+    is_torch_xla_available,
     logging,
     strtobool,
 )
@@ -170,7 +170,7 @@ if is_apex_available():
 if is_datasets_available():
     import datasets
-if is_torch_tpu_available(check_device=False):
+if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
     import torch_xla.debug.metrics as met
     import torch_xla.distributed.spmd as xs
@@ -508,7 +508,7 @@ class Trainer:
                 "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
                 "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
             )
-        if is_torch_tpu_available() and self.optimizer is not None:
+        if is_torch_xla_available() and self.optimizer is not None:
             for param in self.model.parameters():
                 model_device = param.device
                 break
@@ -856,7 +856,7 @@ class Trainer:
     def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
         # Deprecated code
         if self.args.use_legacy_prediction_loop:
-            if is_torch_tpu_available():
+            if is_torch_xla_available():
                 return SequentialDistributedSampler(
                     eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
                 )
@@ -1975,7 +1975,7 @@ class Trainer:
                 if (
                     args.logging_nan_inf_filter
-                    and not is_torch_tpu_available()
+                    and not is_torch_xla_available()
                     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
                 ):
                     # if loss is nan or inf simply add the average of previous logged losses
@@ -2027,7 +2027,7 @@ class Trainer:
                         if hasattr(grad_norm, "item"):
                             grad_norm = grad_norm.item()
                     else:
-                        grad_norm = _grad_norm.item() if _grad_norm is not None else None
+                        grad_norm = _grad_norm
                     # Optimizer step
                     self.optimizer.step()
@@ -2050,7 +2050,7 @@ class Trainer:
                     # PyTorch/XLA relies on the data loader to insert the mark_step for
                     # each step. Since we are breaking the loop early, we need to manually
                     # insert the mark_step here.
-                    if is_torch_tpu_available():
+                    if is_torch_xla_available():
                         xm.mark_step()
                     break
             if step < 0:
@@ -2065,7 +2065,7 @@ class Trainer:
             self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
             if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
-                if is_torch_tpu_available():
+                if is_torch_xla_available():
                     # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                     xm.master_print(met.metrics_report())
                 else:
@@ -2083,7 +2083,7 @@ class Trainer:
         logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
         if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
             # Wait for everyone to get here so we are sure the model has been saved by process 0.
-            if is_torch_tpu_available():
+            if is_torch_xla_available():
                 xm.rendezvous("load_best_model_at_end")
             elif args.parallel_mode == ParallelMode.DISTRIBUTED:
                 dist.barrier()
@@ -2402,7 +2402,7 @@ class Trainer:
     def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
         if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
-            if is_torch_tpu_available():
+            if is_torch_xla_available():
                 xm.mark_step()
             logs: Dict[str, float] = {}
@@ -2415,7 +2415,7 @@ class Trainer:
             logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
             if grad_norm is not None:
-                logs["grad_norm"] = grad_norm
+                logs["grad_norm"] = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm
             logs["learning_rate"] = self._get_learning_rate()
             self._total_loss_scalar += tr_loss_scalar
@@ -2478,7 +2478,7 @@ class Trainer:
                     f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}"
                     "\nThis won't yield the same results as if the training had not been interrupted."
                 )
-        if is_torch_tpu_available():
+        if is_torch_xla_available():
             xm.set_rng_state(checkpoint_rng_state["xla"])
         if is_torch_npu_available():
             if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
@@ -2556,7 +2556,7 @@ class Trainer:
             else:
                 rng_states["cuda"] = torch.cuda.random.get_rng_state()
-        if is_torch_tpu_available():
+        if is_torch_xla_available():
             rng_states["xla"] = xm.get_rng_state()
         if is_torch_npu_available():
@@ -2575,7 +2575,7 @@ class Trainer:
             torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))
     def _save_optimizer_and_scheduler(self, output_dir):
-        if is_torch_tpu_available():
+        if is_torch_xla_available():
            xm.rendezvous("saving_optimizer_states")
            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
            with warnings.catch_warnings(record=True) as caught_warnings:
@@ -2620,7 +2620,7 @@ class Trainer:
         if (
             self.args.should_save
             and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler)
-            and not is_torch_tpu_available()
+            and not is_torch_xla_available()
         ):
             with warnings.catch_warnings(record=True) as caught_warnings:
                 torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
@@ -2657,7 +2657,7 @@ class Trainer:
         )
         if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
             # Load in optimizer and scheduler states
-            if is_torch_tpu_available():
+            if is_torch_xla_available():
                 # On TPU we have to take some extra precautions to properly load the states on the right device.
                 optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
                 with warnings.catch_warnings(record=True) as caught_warnings:
@@ -2964,7 +2964,7 @@ class Trainer:
         if output_dir is None:
             output_dir = self.args.output_dir
-        if is_torch_tpu_available():
+        if is_torch_xla_available():
             self._save_tpu(output_dir)
         elif is_sagemaker_mp_enabled():
             # Calling the state_dict needs to be done on the wrapped model and on all processes.
@@ -3405,7 +3405,7 @@ class Trainer:
             main_input_name = getattr(self.model, "main_input_name", "input_ids")
             inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None
-            if is_torch_tpu_available():
+            if is_torch_xla_available():
                 xm.mark_step()
             # Update containers on host
@@ -3529,7 +3529,7 @@ class Trainer:
         """
         if tensors is None:
             return
-        if is_torch_tpu_available():
+        if is_torch_xla_available():
             if name is None:
                 name = "nested_gather"
             tensors = nested_xla_mesh_reduce(tensors, name)
@@ -4045,7 +4045,7 @@ class Trainer:
         """
         if tensors is None:
             return
-        if is_torch_tpu_available():
+        if is_torch_xla_available():
             tensors = nested_xla_mesh_reduce(tensors, name)
         elif is_sagemaker_mp_enabled():
             tensors = smp_gather(tensors)
...
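One of the commit-message bullets is "fix the deprecation handling": `is_torch_tpu_available` stays importable (it is still listed in the `__init__.py` hunk above) while the codebase moves to `is_torch_xla_available`. The shim below is only a guess at how such a deprecation wrapper is commonly written; the name `is_torch_tpu_available_shim` and the warning text are assumptions, not the code from this commit.

```python
# Hypothetical deprecation shim; the real helper in this commit may differ.
import warnings

from transformers import is_torch_xla_available


def is_torch_tpu_available_shim(check_device: bool = True) -> bool:
    warnings.warn(
        "`is_torch_tpu_available` is deprecated; use `is_torch_xla_available` instead.",
        FutureWarning,
    )
    # Delegate to the new check; the old check_device argument is accepted but ignored here.
    return is_torch_xla_available()
```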