"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "3c682ea15cf50636360545ba88a325868d194b0d"
Unverified commit ebf80e2e authored by Lysandre Debut, committed by GitHub

Tpu trainer (#4146)



* wip

* wip

* a last wip

* Better logging when using TPUs

* Correct argument name

* Tests

* fix

* Metrics in evaluation

* Update src/transformers/training_args.py

* [tpu] Use launcher script instead

* [tpu] lots of tweaks

* Fix formatting
Co-authored-by: Julien Chaumond <chaumond@gmail.com>
parent 026097b9
@@ -202,5 +202,10 @@ def main():
     return results
 
 
+def _mp_fn(index):
+    # For xla_spawn (TPUs)
+    main()
+
+
 if __name__ == "__main__":
     main()
"""
A simple launcher script for TPU training
Inspired by https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py
::
>>> python xla_spawn.py --num_cores=NUM_CORES_YOU_HAVE
YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other
arguments of your training script)
"""
import importlib
import os
import sys
from argparse import REMAINDER, ArgumentParser
import torch_xla.distributed.xla_multiprocessing as xmp
def trim_suffix(s: str, suffix: str):
return s if not s.endswith(suffix) or len(suffix) == 0 else s[: -len(suffix)]
def parse_args():
"""
Helper function parsing the command line options
@retval ArgumentParser
"""
parser = ArgumentParser(
description=(
"PyTorch TPU distributed training launch "
"helper utility that will spawn up "
"multiple distributed processes"
)
)
# Optional arguments for the launch helper
parser.add_argument("--num_cores", type=int, default=1, help="Number of TPU cores to use (1 or 8).")
# positional
parser.add_argument(
"training_script",
type=str,
help=(
"The full module name to the single TPU training "
"program/script to be launched in parallel, "
"followed by all the arguments for the "
"training script"
),
)
# rest from the training program
parser.add_argument("training_script_args", nargs=REMAINDER)
return parser.parse_args()
def main():
args = parse_args()
# Import training_script as a module.
mod_name = trim_suffix(os.path.basename(args.training_script), ".py")
mod = importlib.import_module(mod_name)
# Patch sys.argv
sys.argv = [args.training_script] + args.training_script_args + ["--tpu_num_cores", str(args.num_cores)]
xmp.spawn(mod._mp_fn, args=(), nprocs=args.num_cores)
if __name__ == "__main__":
main()
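For context, here is a minimal sketch of a training module that the launcher above could drive. The file name and argument handling are illustrative, not part of this commit, but the _mp_fn hook mirrors the one added to the example script in the first hunk.

# hypothetical_train_script.py -- stand-in for one of the example training scripts
import argparse


def main():
    parser = argparse.ArgumentParser()
    # xla_spawn.py appends --tpu_num_cores to sys.argv before spawning.
    parser.add_argument("--tpu_num_cores", type=int, default=None)
    args = parser.parse_args()
    print(f"running with tpu_num_cores={args.tpu_num_cores}")


def _mp_fn(index):
    # xla_spawn.py imports this module and hands _mp_fn to xmp.spawn;
    # index is the per-process ordinal assigned by torch_xla.
    main()


if __name__ == "__main__":
    main()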
@@ -21,7 +21,7 @@ from .data.data_collator import DataCollator, DefaultDataCollator
 from .modeling_utils import PreTrainedModel
 from .optimization import AdamW, get_linear_schedule_with_warmup
 from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, TrainOutput
-from .training_args import TrainingArguments
+from .training_args import TrainingArguments, is_tpu_available
 
 
 try:
@@ -36,6 +36,11 @@ def is_apex_available():
     return _has_apex
 
 
+if is_tpu_available():
+    import torch_xla.core.xla_model as xm
+    import torch_xla.debug.metrics as met
+    import torch_xla.distributed.parallel_loader as pl
+
+
 try:
     from torch.utils.tensorboard import SummaryWriter
@@ -88,6 +93,12 @@ def torch_distributed_zero_first(local_rank: int):
         torch.distributed.barrier()
 
 
+def get_tpu_sampler(dataset: Dataset):
+    if xm.xrt_world_size() <= 1:
+        return RandomSampler(dataset)
+    return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
+
+
 class Trainer:
     """
     Trainer is a simple but feature-complete training and eval loop for PyTorch,
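A rough sketch of what get_tpu_sampler amounts to when more than one core is in use; the world size and ordinal are hard-coded here in place of xm.xrt_world_size() and xm.get_ordinal() so the snippet runs without a TPU.

import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))
# Pretend we are ordinal 0 of an 8-core TPU: each replica iterates over its own shard.
sampler = DistributedSampler(dataset, num_replicas=8, rank=0)
print(list(sampler))  # the indices this replica would see for the current epoch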
@@ -146,41 +157,73 @@ class Trainer:
         )
         set_seed(self.args.seed)
         # Create output directory if needed
-        if self.args.local_rank in [-1, 0]:
+        if self.is_local_master():
             os.makedirs(self.args.output_dir, exist_ok=True)
+        if is_tpu_available():
+            # Set an xla_device flag on the model's config.
+            # We'll find a more elegant way and not need to do this in the future.
+            self.model.config.xla_device = True
 
     def get_train_dataloader(self) -> DataLoader:
         if self.train_dataset is None:
             raise ValueError("Trainer: training requires a train_dataset.")
-        train_sampler = (
-            RandomSampler(self.train_dataset) if self.args.local_rank == -1 else DistributedSampler(self.train_dataset)
-        )
-        return DataLoader(
+        if is_tpu_available():
+            train_sampler = get_tpu_sampler(self.train_dataset)
+        else:
+            train_sampler = (
+                RandomSampler(self.train_dataset)
+                if self.args.local_rank == -1
+                else DistributedSampler(self.train_dataset)
+            )
+
+        data_loader = DataLoader(
             self.train_dataset,
             batch_size=self.args.train_batch_size,
             sampler=train_sampler,
             collate_fn=self.data_collator.collate_batch,
         )
+        if is_tpu_available():
+            data_loader = pl.ParallelLoader(data_loader, [self.args.device]).per_device_loader(self.args.device)
+
+        return data_loader
 
     def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
         if eval_dataset is None and self.eval_dataset is None:
             raise ValueError("Trainer: evaluation requires an eval_dataset.")
-        return DataLoader(
+
+        sampler = get_tpu_sampler(eval_dataset) if is_tpu_available() else None
+
+        data_loader = DataLoader(
             eval_dataset if eval_dataset is not None else self.eval_dataset,
+            sampler=sampler,
             batch_size=self.args.eval_batch_size,
             shuffle=False,
             collate_fn=self.data_collator.collate_batch,
         )
+        if is_tpu_available():
+            data_loader = pl.ParallelLoader(data_loader, [self.args.device]).per_device_loader(self.args.device)
+
+        return data_loader
 
     def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
         # We use the same batch_size as for eval.
-        return DataLoader(
+        sampler = get_tpu_sampler(test_dataset) if is_tpu_available() else None
+
+        data_loader = DataLoader(
             test_dataset,
+            sampler=sampler,
             batch_size=self.args.eval_batch_size,
             shuffle=False,
             collate_fn=self.data_collator.collate_batch,
         )
+        if is_tpu_available():
+            data_loader = pl.ParallelLoader(data_loader, [self.args.device]).per_device_loader(self.args.device)
+
+        return data_loader
 
     def get_optimizers(
         self, num_training_steps: int
     ) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]:
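For reference, the ParallelLoader wrapping used in the three dataloader methods above boils down to the pattern below. This is only a sketch: it requires a host with torch_xla and an attached TPU, and the tiny dataset is a stand-in.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
from torch.utils.data import DataLoader, TensorDataset

device = xm.xla_device()
loader = DataLoader(TensorDataset(torch.arange(16).float()), batch_size=4)
# ParallelLoader preloads batches onto the XLA device in background threads;
# per_device_loader yields batches that already live on the device.
for (batch,) in pl.ParallelLoader(loader, [device]).per_device_loader(device):
    print(batch.device, batch.shape)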
@@ -222,7 +265,6 @@ class Trainer:
             If present, we will try reloading the optimizer/scheduler states from there.
         """
         train_dataloader = self.get_train_dataloader()
-
         if self.args.max_steps > 0:
             t_total = self.args.max_steps
             num_train_epochs = (
@@ -271,16 +313,21 @@ class Trainer:
            self._setup_wandb()
 
         # Train!
+        if is_tpu_available():
+            num_examples = len(train_dataloader._loader._loader.dataset)
+            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
+        else:
+            num_examples = len(train_dataloader.dataset)
+            total_train_batch_size = (
+                self.args.train_batch_size
+                * self.args.gradient_accumulation_steps
+                * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
+            )
         logger.info("***** Running training *****")
-        logger.info("  Num examples = %d", len(train_dataloader.dataset))
+        logger.info("  Num examples = %d", num_examples)
         logger.info("  Num Epochs = %d", num_train_epochs)
-        logger.info("  Instantaneous batch size per GPU = %d", self.args.per_gpu_train_batch_size)
-        logger.info(
-            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
-            self.args.train_batch_size
-            * self.args.gradient_accumulation_steps
-            * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1),
-        )
+        logger.info("  Instantaneous batch size per device = %d", self.args.per_gpu_train_batch_size)
+        logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
         logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
         logger.info("  Total optimization steps = %d", t_total)
@@ -309,10 +356,10 @@ class Trainer:
         logging_loss = 0.0
         model.zero_grad()
         train_iterator = trange(
-            epochs_trained, int(num_train_epochs), desc="Epoch", disable=self.args.local_rank not in [-1, 0],
+            epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_master()
         )
         for epoch in train_iterator:
-            epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=self.args.local_rank not in [-1, 0])
+            epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master())
             for step, inputs in enumerate(epoch_iterator):
 
                 # Skip past any already trained steps if resuming training
@@ -332,12 +379,16 @@ class Trainer:
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
 
-                    optimizer.step()
+                    if is_tpu_available():
+                        xm.optimizer_step(optimizer)
+                    else:
+                        optimizer.step()
+
                     scheduler.step()
                     model.zero_grad()
                     global_step += 1
 
-                    if self.args.local_rank in [-1, 0]:
+                    if self.is_local_master():
                         if (self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0) or (
                             global_step == 1 and self.args.logging_first_step
                         ):
@@ -371,6 +422,7 @@ class Trainer:
                            assert model is self.model
                            # Save model checkpoint
                            output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{global_step}")
+
                            self.save_model(output_dir)
                            self._rotate_checkpoints()
                            torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
@@ -383,6 +435,9 @@ class Trainer:
             if self.args.max_steps > 0 and global_step > self.args.max_steps:
                 train_iterator.close()
                 break
+            if self.args.tpu_metrics_debug:
+                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
+                xm.master_print(met.metrics_report())
 
         if self.tb_writer:
             self.tb_writer.close()
@@ -413,12 +468,21 @@ class Trainer:
 
         return loss.item()
 
+    def is_local_master(self) -> bool:
+        if is_tpu_available():
+            return xm.is_master_ordinal(local=True)
+        else:
+            return self.args.local_rank in [-1, 0]
+
     def is_world_master(self) -> bool:
         """
         This will be True only in one process, even in distributed mode,
         even when training on multiple machines.
         """
-        return self.args.local_rank == -1 or torch.distributed.get_rank() == 0
+        if is_tpu_available():
+            return xm.is_master_ordinal(local=False)
+        else:
+            return self.args.local_rank == -1 or torch.distributed.get_rank() == 0
 
     def save_model(self, output_dir: Optional[str] = None):
         """
@@ -495,6 +559,11 @@ class Trainer:
         eval_dataloader = self.get_eval_dataloader(eval_dataset)
 
         output = self._prediction_loop(eval_dataloader, description="Evaluation")
+
+        if self.args.tpu_metrics_debug:
+            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
+            xm.master_print(met.metrics_report())
+
         return output.metrics
 
     def predict(self, test_dataset: Dataset) -> PredictionOutput:
@@ -558,6 +627,11 @@ class Trainer:
                    else:
                        label_ids = np.append(label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)
 
+        if is_tpu_available():
+            # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
+            preds = xm.mesh_reduce("eval_preds", preds, np.concatenate)
+            label_ids = xm.mesh_reduce("eval_out_label_ids", label_ids, np.concatenate)
+
         if self.compute_metrics is not None and preds is not None and label_ids is not None:
             metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
         else:
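xm.mesh_reduce gathers the tagged value from every TPU process and applies the given reduce function to the collected list; with np.concatenate as above, the effect is roughly the following single-host sketch (the per-worker arrays are made up).

import numpy as np

# What two hypothetical eval workers might hold after predicting on their shards:
shard_preds = [np.array([[0.9, 0.1], [0.2, 0.8]]), np.array([[0.4, 0.6]])]
merged_preds = np.concatenate(shard_preds)  # what every worker gets back from mesh_reduce
print(merged_preds.shape)  # (3, 2)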
@@ -11,6 +11,19 @@ if is_torch_available():
     import torch
 
 
+try:
+    import torch_xla.core.xla_model as xm
+
+    _has_tpu = True
+except ImportError:
+    _has_tpu = False
+
+
+@torch_required
+def is_tpu_available():
+    return _has_tpu
+
+
 logger = logging.getLogger(__name__)
@@ -77,7 +90,7 @@ class TrainingArguments:
            )
        },
    )
-    no_cuda: bool = field(default=False, metadata={"help": "Avoid using CUDA even if it is available"})
+    no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"})
     seed: int = field(default=42, metadata={"help": "random seed for initialization"})
 
     fp16: bool = field(
@@ -95,6 +108,11 @@ class TrainingArguments:
    )
 
     local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
+    tpu_num_cores: Optional[int] = field(
+        default=None, metadata={"help": "TPU: Number of TPU cores (automatically passed by launcher script)"}
+    )
+    tpu_metrics_debug: bool = field(default=False, metadata={"help": "TPU: Whether to print debug metrics"})
+
     @property
     def train_batch_size(self) -> int:
         return self.per_gpu_train_batch_size * max(1, self.n_gpu)
@@ -110,6 +128,9 @@ class TrainingArguments:
         if self.no_cuda:
             device = torch.device("cpu")
             n_gpu = 0
+        elif is_tpu_available():
+            device = xm.xla_device()
+            n_gpu = 0
         elif self.local_rank == -1:
             # if n_gpu is > 1 we'll use nn.DataParallel.
             # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
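Putting the new arguments together, here is a hedged sketch of how a script might inspect them; the values are illustrative, and on a host without torch_xla the device falls back to the usual CUDA/CPU logic.

from transformers import TrainingArguments

# tpu_num_cores is normally injected by xla_spawn.py via --tpu_num_cores.
args = TrainingArguments(output_dir="/tmp/tpu_test", tpu_num_cores=8, tpu_metrics_debug=True)
print(args.device)            # xm.xla_device() when a TPU is available, otherwise cuda/cpu
print(args.train_batch_size)  # per_gpu_train_batch_size * max(1, n_gpu)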