Unverified Commit 7b75aa9f authored by Julien Chaumond, committed by GitHub

[TPU] Doc, fix xla_spawn.py, only preprocess dataset once (#4223)

* [TPU] Doc, fix xla_spawn.py, only preprocess dataset once

* Update examples/README.md

* [xla_spawn] Add `_mp_fn` to other Trainer scripts

* [TPU] Fix: eval dataloader was None
parent 274d850d
......@@ -53,4 +53,28 @@ pip install -r ./examples/requirements.txt
## Running on TPUs
Documentation to come.
When using TensorFlow, TPUs are supported out of the box as a `tf.distribute.Strategy`.
When using PyTorch, we support TPUs thanks to `pytorch/xla`. For more context and information on how to set up your TPU environment, refer to Google's documentation and to the
very detailed [pytorch/xla README](https://github.com/pytorch/xla/blob/master/README.md).
In this repo, we provide a very simple launcher script named [xla_spawn.py](./xla_spawn.py) that lets you run our example scripts on multiple TPU cores without any boilerplate.
Just pass a `--num_cores` flag to this script, then your regular training script with its arguments (this is similar to the `torch.distributed.launch` helper for torch.distributed).
For example for `run_glue`:
```bash
python examples/xla_spawn.py --num_cores 8 \
    examples/text-classification/run_glue.py \
--model_name_or_path bert-base-cased \
--task_name mnli \
--data_dir ./data/glue_data/MNLI \
--output_dir ./models/tpu \
--overwrite_output_dir \
--do_train \
--do_eval \
--num_train_epochs 1 \
--save_steps 20000
```
Feedback and more use cases and benchmarks involving TPUs are welcome; please share with the community.
......@@ -404,7 +404,7 @@ def main():
logger.info("Training/evaluation parameters %s", args)
# Prepare dataset for the GLUE task
eval_dataset = GlueDataset(args, tokenizer=tokenizer, evaluate=True, local_rank=args.local_rank)
eval_dataset = GlueDataset(args, tokenizer=tokenizer, evaluate=True)
if args.data_subset > 0:
eval_dataset = Subset(eval_dataset, list(range(min(args.data_subset, len(eval_dataset)))))
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
......
......@@ -280,5 +280,10 @@ def main():
return results
def _mp_fn(index):
# For xla_spawn (TPUs)
main()
if __name__ == "__main__":
main()
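
For context, `_mp_fn` is the per-process entry point that `torch_xla`'s multiprocessing spawner calls on each TPU core; the `index` argument is the local process ordinal and is unused here because the script re-parses its own arguments. A minimal sketch of how a launcher invokes it (the module name `run_glue` is illustrative):

```python
# Sketch: spawning one process per TPU core.
# Assumes torch_xla is installed; `run_glue` stands in for any training
# script that defines `_mp_fn(index)` as above.
import torch_xla.distributed.xla_multiprocessing as xmp

import run_glue  # hypothetical import of the training script as a module

if __name__ == "__main__":
    # xmp.spawn forks `nprocs` processes and calls run_glue._mp_fn(index)
    # in each one, passing the process index as the first argument.
    xmp.spawn(run_glue._mp_fn, args=(), nprocs=8)
```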
......@@ -221,5 +221,10 @@ def main():
return results
def _mp_fn(index):
# For xla_spawn (TPUs)
main()
if __name__ == "__main__":
main()
......@@ -85,10 +85,12 @@ CoLA, SST-2. The following section provides details on how to run half-precision
said, there shouldn’t be any issues in running half-precision training with the remaining GLUE tasks as well,
since the data processor for each task inherits from the base class DataProcessor.
## Running on TPUs
## Running on TPUs in PyTorch
You can accelerate your workloads on Google's TPUs. For information on how to setup your TPU environment refer to this
[README](https://github.com/pytorch/xla/blob/master/README.md).
**Update**: read the more up-to-date [Running on TPUs](../README.md#running-on-tpus) in the main README.md instead.
Even when running PyTorch, you can accelerate your workloads on Google's TPUs using `pytorch/xla`. For information on how to set up your TPU environment, refer to the
[pytorch/xla README](https://github.com/pytorch/xla/blob/master/README.md).
The following are some examples of running the `*_tpu.py` finetuning scripts on TPUs. All steps for data preparation are
identical to your normal GPU + Huggingface setup.
......@@ -101,7 +103,6 @@ export GLUE_DIR=/path/to/glue
export TASK_NAME=MNLI
python run_glue_tpu.py \
--model_type bert \
--model_name_or_path bert-base-cased \
--task_name $TASK_NAME \
--do_train \
......@@ -115,8 +116,7 @@ python run_glue_tpu.py \
--overwrite_output_dir \
--logging_steps 50 \
--save_steps 200 \
--num_cores=8 \
--only_log_master
--num_cores=8
```
### MRPC
......
......@@ -134,16 +134,8 @@ def main():
)
# Get datasets
train_dataset = (
GlueDataset(data_args, tokenizer=tokenizer, local_rank=training_args.local_rank)
if training_args.do_train
else None
)
eval_dataset = (
GlueDataset(data_args, tokenizer=tokenizer, local_rank=training_args.local_rank, evaluate=True)
if training_args.do_eval
else None
)
train_dataset = GlueDataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
def compute_metrics(p: EvalPrediction) -> Dict:
if output_mode == "classification":
......@@ -181,9 +173,7 @@ def main():
eval_datasets = [eval_dataset]
if data_args.task_name == "mnli":
mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
eval_datasets.append(
GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, local_rank=training_args.local_rank, evaluate=True)
)
eval_datasets.append(GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, evaluate=True))
for eval_dataset in eval_datasets:
result = trainer.evaluate(eval_dataset=eval_dataset)
......
......@@ -292,5 +292,10 @@ def main():
return results
def _mp_fn(index):
# For xla_spawn (TPUs)
main()
if __name__ == "__main__":
main()
......@@ -12,17 +12,13 @@ Inspired by https://github.com/pytorch/pytorch/blob/master/torch/distributed/lau
import importlib
import os
import sys
from argparse import REMAINDER, ArgumentParser
from pathlib import Path
import torch_xla.distributed.xla_multiprocessing as xmp
def trim_suffix(s: str, suffix: str):
return s if not s.endswith(suffix) or len(suffix) == 0 else s[: -len(suffix)]
def parse_args():
"""
Helper function parsing the command line options
......@@ -44,7 +40,7 @@ def parse_args():
"training_script",
type=str,
help=(
"The full module name to the single TPU training "
"The full path to the single TPU training "
"program/script to be launched in parallel, "
"followed by all the arguments for the "
"training script"
......@@ -61,7 +57,9 @@ def main():
args = parse_args()
# Import training_script as a module.
mod_name = trim_suffix(os.path.basename(args.training_script), ".py")
script_fpath = Path(args.training_script)
sys.path.append(str(script_fpath.parent.resolve()))
mod_name = script_fpath.stem
mod = importlib.import_module(mod_name)
# Patch sys.argv
......
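
Taken together, the launcher now resolves the training script as a file path (adding its directory to `sys.path`) instead of requiring a module name, which lets the `trim_suffix` helper be dropped. A sketch of the resulting `main()`, under the assumption that the parts not shown in this hunk (the collected `training_script_args` and a `--tpu_num_cores` pass-through) look roughly like this:

```python
# Sketch of the launcher's main(); `args.training_script_args` and the
# `--tpu_num_cores` flag are assumptions about code outside this hunk.
import importlib
import sys
from pathlib import Path

import torch_xla.distributed.xla_multiprocessing as xmp


def main():
    args = parse_args()  # defined earlier in xla_spawn.py

    # Resolve the script path so it can be imported as a module,
    # regardless of the directory the launcher is run from.
    script_fpath = Path(args.training_script)
    sys.path.append(str(script_fpath.parent.resolve()))
    mod = importlib.import_module(script_fpath.stem)

    # Rebuild sys.argv so the imported script's own argparse sees only
    # its arguments (plus, presumably, the number of TPU cores).
    sys.argv = [args.training_script] + args.training_script_args + ["--tpu_num_cores", str(args.num_cores)]

    # One process per TPU core; each process calls the script's _mp_fn(index).
    xmp.spawn(mod._mp_fn, args=(), nprocs=args.num_cores)
```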
......@@ -5,12 +5,12 @@ from dataclasses import dataclass, field
from typing import List, Optional
import torch
from filelock import FileLock
from torch.utils.data.dataset import Dataset
from ...tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_xlm_roberta import XLMRobertaTokenizer
from ...trainer import torch_distributed_zero_first
from ..processors.glue import glue_convert_examples_to_features, glue_output_modes, glue_processors
from ..processors.utils import InputFeatures
......@@ -63,7 +63,6 @@ class GlueDataset(Dataset):
tokenizer: PreTrainedTokenizer,
limit_length: Optional[int] = None,
evaluate=False,
local_rank=-1,
):
self.args = args
processor = glue_processors[args.task_name]()
......@@ -75,9 +74,11 @@ class GlueDataset(Dataset):
"dev" if evaluate else "train", tokenizer.__class__.__name__, str(args.max_seq_length), args.task_name,
),
)
with torch_distributed_zero_first(local_rank):
# Make sure only the first process in distributed training processes the dataset,
# and the others will use the cache.
lock_path = cached_features_file + ".lock"
with FileLock(lock_path):
if os.path.exists(cached_features_file) and not args.overwrite_cache:
start = time.time()
......@@ -109,7 +110,6 @@ class GlueDataset(Dataset):
label_list=label_list,
output_mode=self.output_mode,
)
if local_rank in [-1, 0]:
start = time.time()
torch.save(self.features, cached_features_file)
# ^ This seems to take a lot of time so I want to investigate why and how we can improve.
......
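
Replacing `torch_distributed_zero_first(local_rank)` with a `FileLock` is what makes the dataset get preprocessed only once, including under `xla_spawn`: whichever process acquires the lock first builds and saves the features, while the others block on the lock and then read the cache. A minimal sketch of the pattern, with an illustrative `build_features` callable standing in for `glue_convert_examples_to_features`:

```python
# Sketch of the lock-then-cache pattern used above; `build_features`
# is an illustrative placeholder for the actual feature conversion.
import os
import time

import torch
from filelock import FileLock


def load_or_build_features(cached_features_file, build_features, overwrite_cache=False):
    # Only one process per machine holds the lock at a time, so the
    # expensive preprocessing runs once; everyone else reuses the cache.
    lock_path = cached_features_file + ".lock"
    with FileLock(lock_path):
        if os.path.exists(cached_features_file) and not overwrite_cache:
            start = time.time()
            features = torch.load(cached_features_file)
            print(f"Loaded cached features in {time.time() - start:.1f}s")
        else:
            features = build_features()
            torch.save(features, cached_features_file)
    return features
```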
......@@ -6,7 +6,7 @@ import re
import shutil
from contextlib import contextmanager
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
......@@ -195,10 +195,12 @@ class Trainer:
if eval_dataset is None and self.eval_dataset is None:
raise ValueError("Trainer: evaluation requires an eval_dataset.")
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
sampler = get_tpu_sampler(eval_dataset) if is_tpu_available() else None
data_loader = DataLoader(
eval_dataset if eval_dataset is not None else self.eval_dataset,
eval_dataset,
sampler=sampler,
batch_size=self.args.eval_batch_size,
shuffle=False,
......@@ -267,6 +269,16 @@ class Trainer:
# keep track of model topology and gradients
wandb.watch(self.model)
def num_examples(self, dataloader: Union[DataLoader, "pl.PerDeviceLoader"]) -> int:
"""
Helper to get num of examples from a DataLoader, by accessing its Dataset.
"""
if is_tpu_available():
assert isinstance(dataloader, pl.PerDeviceLoader)
return len(dataloader._loader._loader.dataset)
else:
return len(dataloader.dataset)
def train(self, model_path: Optional[str] = None):
"""
Main training entry point.
......@@ -326,17 +338,15 @@ class Trainer:
# Train!
if is_tpu_available():
num_examples = len(train_dataloader._loader._loader.dataset)
total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
else:
num_examples = len(train_dataloader.dataset)
total_train_batch_size = (
self.args.train_batch_size
* self.args.gradient_accumulation_steps
* (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
)
logger.info("***** Running training *****")
logger.info(" Num examples = %d", num_examples)
logger.info(" Num examples = %d", self.num_examples(train_dataloader))
logger.info(" Num Epochs = %d", num_train_epochs)
logger.info(" Instantaneous batch size per device = %d", self.args.per_gpu_train_batch_size)
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
......@@ -606,9 +616,13 @@ class Trainer:
model = self.model
model.to(self.args.device)
if is_tpu_available():
batch_size = dataloader._loader._loader.batch_size
else:
batch_size = dataloader.batch_size
logger.info("***** Running %s *****", description)
logger.info(" Num examples = %d", len(dataloader.dataset))
logger.info(" Batch size = %d", dataloader.batch_size)
logger.info(" Num examples = %d", self.num_examples(dataloader))
logger.info(" Batch size = %d", batch_size)
eval_losses: List[float] = []
preds: np.ndarray = None
label_ids: np.ndarray = None
......
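
The double `._loader._loader` unwrapping in `num_examples` (and in the batch-size lookup above) reflects how, on TPU, the `DataLoader` is presumably wrapped in a `ParallelLoader` whose `per_device_loader()` returns a `pl.PerDeviceLoader`; each layer keeps a reference to the loader it wraps. A sketch of that wrapping, with `get_tpu_sampler` shown as an assumed implementation inferred from its usage in `get_eval_dataloader`:

```python
# Sketch: why num_examples() reaches through dataloader._loader._loader.dataset
# on TPU. get_tpu_sampler is an assumed implementation, not the library's.
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler


def get_tpu_sampler(dataset):
    # Shard the dataset across TPU processes (single process: plain sampler).
    if xm.xrt_world_size() <= 1:
        return SequentialSampler(dataset)
    return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())


def tpu_eval_loader(eval_dataset, batch_size):
    device = xm.xla_device()
    data_loader = DataLoader(eval_dataset, sampler=get_tpu_sampler(eval_dataset), batch_size=batch_size)
    # ParallelLoader wraps the DataLoader; per_device_loader() returns a
    # PerDeviceLoader whose ._loader is the ParallelLoader, whose ._loader
    # in turn is the original DataLoader -- hence the double unwrap.
    return pl.ParallelLoader(data_loader, [device]).per_device_loader(device)
```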