Unverified Commit 5e7fe8b5 authored by Julien Chaumond, committed by GitHub

Distributed eval: SequentialDistributedSampler + gather all results (#4243)

* Distributed eval: SequentialDistributedSampler + gather all results

* For consistency only write to disk from world_master

Close https://github.com/huggingface/transformers/issues/4272

* Working distributed eval

* Hook into scripts

* Fix #3721 again

* TPU.mesh_reduce: stay in tensor space

Thanks @jysohn23

* Just a small comment

* whitespace

* torch.hub: pip install packaging

* Add test scenarios
parent 4c068936
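The example-script diffs below all apply the same pattern, roughly (a minimal sketch using only names that appear in this diff): every process runs evaluation, the Trainer gathers the per-shard results internally, and only the world master writes them to disk.

    result = trainer.evaluate()
    if trainer.is_world_master():
        output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            for key, value in result.items():
                writer.write("%s = %s\n" % (key, value))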
@@ -21,7 +21,7 @@ jobs:
- name: Install dependencies
run: |
pip install torch
pip install numpy tokenizers filelock requests tqdm regex sentencepiece sacremoses
pip install numpy tokenizers filelock requests tqdm regex sentencepiece sacremoses packaging
- name: Torch hub list
run: |
@@ -251,7 +251,7 @@ def main():
# Evaluation
results = {}
if training_args.do_eval and training_args.local_rank in [-1, 0]:
if training_args.do_eval:
logger.info("*** Evaluate ***")
eval_output = trainer.evaluate()
@@ -260,11 +260,12 @@ def main():
result = {"perplexity": perplexity}
output_eval_file = os.path.join(training_args.output_dir, "eval_results_lm.txt")
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
for key in sorted(result.keys()):
logger.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))
if trainer.is_world_master():
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
for key in sorted(result.keys()):
logger.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))
results.update(result)
@@ -202,19 +202,20 @@ def main():
# Evaluation
results = {}
if training_args.do_eval and training_args.local_rank in [-1, 0]:
if training_args.do_eval:
logger.info("*** Evaluate ***")
result = trainer.evaluate()
output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
for key, value in result.items():
logger.info(" %s = %s", key, value)
writer.write("%s = %s\n" % (key, value))
if trainer.is_world_master():
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
for key, value in result.items():
logger.info(" %s = %s", key, value)
writer.write("%s = %s\n" % (key, value))
results.update(result)
results.update(result)
return results
@@ -166,7 +166,7 @@ def main():
# Evaluation
results = {}
if training_args.do_eval and training_args.local_rank in [-1, 0]:
if training_args.do_eval:
logger.info("*** Evaluate ***")
# Loop to handle MNLI double evaluation (matched, mis-matched)
@@ -181,11 +181,12 @@ def main():
output_eval_file = os.path.join(
training_args.output_dir, f"eval_results_{eval_dataset.args.task_name}.txt"
)
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results {} *****".format(eval_dataset.args.task_name))
for key, value in result.items():
logger.info(" %s = %s", key, value)
writer.write("%s = %s\n" % (key, value))
if trainer.is_world_master():
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results {} *****".format(eval_dataset.args.task_name))
for key, value in result.items():
logger.info(" %s = %s", key, value)
writer.write("%s = %s\n" % (key, value))
results.update(result)
@@ -235,22 +235,23 @@ def main():
# Evaluation
results = {}
if training_args.do_eval and training_args.local_rank in [-1, 0]:
if training_args.do_eval:
logger.info("*** Evaluate ***")
result = trainer.evaluate()
output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
for key, value in result.items():
logger.info(" %s = %s", key, value)
writer.write("%s = %s\n" % (key, value))
if trainer.is_world_master():
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
for key, value in result.items():
logger.info(" %s = %s", key, value)
writer.write("%s = %s\n" % (key, value))
results.update(result)
# Predict
if training_args.do_predict and training_args.local_rank in [-1, 0]:
if training_args.do_predict:
test_dataset = NerDataset(
data_dir=data_args.data_dir,
tokenizer=tokenizer,
@@ -265,26 +266,30 @@ def main():
preds_list, _ = align_predictions(predictions, label_ids)
output_test_results_file = os.path.join(training_args.output_dir, "test_results.txt")
with open(output_test_results_file, "w") as writer:
for key, value in metrics.items():
logger.info(" %s = %s", key, value)
writer.write("%s = %s\n" % (key, value))
if trainer.is_world_master():
with open(output_test_results_file, "w") as writer:
for key, value in metrics.items():
logger.info(" %s = %s", key, value)
writer.write("%s = %s\n" % (key, value))
# Save predictions
output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt")
with open(output_test_predictions_file, "w") as writer:
with open(os.path.join(data_args.data_dir, "test.txt"), "r") as f:
example_id = 0
for line in f:
if line.startswith("-DOCSTART-") or line == "" or line == "\n":
writer.write(line)
if not preds_list[example_id]:
example_id += 1
elif preds_list[example_id]:
output_line = line.split()[0] + " " + preds_list[example_id].pop(0) + "\n"
writer.write(output_line)
else:
logger.warning("Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0])
if trainer.is_world_master():
with open(output_test_predictions_file, "w") as writer:
with open(os.path.join(data_args.data_dir, "test.txt"), "r") as f:
example_id = 0
for line in f:
if line.startswith("-DOCSTART-") or line == "" or line == "\n":
writer.write(line)
if not preds_list[example_id]:
example_id += 1
elif preds_list[example_id]:
output_line = line.split()[0] + " " + preds_list[example_id].pop(0) + "\n"
writer.write(output_line)
else:
logger.warning(
"Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0]
)
return results
import json
import logging
import math
import os
import random
import re
@@ -15,7 +16,7 @@ from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler
from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
from tqdm.auto import tqdm, trange
from .data.data_collator import DataCollator, DefaultDataCollator
@@ -90,7 +91,7 @@ def set_seed(seed: int):
@contextmanager
def torch_distributed_zero_first(local_rank: int):
"""
Decorator to make all processes in distributed training wait for the first one (locally) to do something.
Decorator to make all processes in distributed training wait for each local_master to do something.
"""
if local_rank not in [-1, 0]:
torch.distributed.barrier()
@@ -99,6 +100,50 @@ def torch_distributed_zero_first(local_rank: int):
torch.distributed.barrier()
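A minimal usage sketch for the context manager above (the dataset helper is hypothetical, not part of this diff): the local master does the cache-populating work first, the other processes wait at the barrier and then read from the cache.

    with torch_distributed_zero_first(training_args.local_rank):
        # hypothetical helper: only local rank 0 downloads and caches the data;
        # the remaining local ranks block at the barrier, then load from the cache
        dataset = load_and_cache_examples(data_args, tokenizer)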
class SequentialDistributedSampler(Sampler):
"""
Distributed Sampler that subsamples indices sequentially,
making it easier to collate all results at the end.
Even though we only use this sampler for eval and predict (no training),
which means that the model params won't have to be synced (i.e. the loop will not
hang waiting for synchronization even if the number of forward passes varies across
processes), we still add extra samples so the dataset is evenly divisible across
replicas (like in `DistributedSampler`), which makes it easy to `gather` or `reduce`
the resulting tensors at the end of the loop.
"""
def __init__(self, dataset, num_replicas=None, rank=None):
if num_replicas is None:
if not torch.distributed.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = torch.distributed.get_world_size()
if rank is None:
if not torch.distributed.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = torch.distributed.get_rank()
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
def __iter__(self):
indices = list(range(len(self.dataset)))
# add extra samples to make it evenly divisible
indices += indices[: (self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
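As an illustration of the sharding above, assuming 7 examples and 2 replicas: num_samples = ceil(7 / 2) = 4 and total_size = 8, so index 0 is repeated once as padding; rank 0 iterates [0, 1, 2, 3] and rank 1 iterates [4, 5, 6, 0]. The padded duplicate is dropped later by distributed_concat.

    # illustration only: shard a 7-example dataset across 2 replicas
    for rank in range(2):
        sampler = SequentialDistributedSampler(list(range(7)), num_replicas=2, rank=rank)
        print(rank, list(sampler))  # rank 0 -> [0, 1, 2, 3], rank 1 -> [4, 5, 6, 0]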
def get_tpu_sampler(dataset: Dataset):
if xm.xrt_world_size() <= 1:
return RandomSampler(dataset)
@@ -156,7 +201,7 @@ class Trainer:
self.optimizers = optimizers
if tb_writer is not None:
self.tb_writer = tb_writer
elif is_tensorboard_available() and self.args.local_rank in [-1, 0]:
elif is_tensorboard_available() and self.is_world_master():
self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
if not is_tensorboard_available():
logger.warning(
@@ -171,7 +216,7 @@ class Trainer:
)
set_seed(self.args.seed)
# Create output directory if needed
if self.is_local_master():
if self.is_world_master():
os.makedirs(self.args.output_dir, exist_ok=True)
if is_tpu_available():
# Set an xla_device flag on the model's config.
@@ -208,13 +253,19 @@ class Trainer:
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
sampler = get_tpu_sampler(eval_dataset) if is_tpu_available() else None
if is_tpu_available():
sampler = SequentialDistributedSampler(
eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
)
elif self.args.local_rank != -1:
sampler = SequentialDistributedSampler(eval_dataset)
else:
sampler = SequentialSampler(eval_dataset)
data_loader = DataLoader(
eval_dataset,
sampler=sampler,
batch_size=self.args.eval_batch_size,
shuffle=False,
collate_fn=self.data_collator.collate_batch,
)
@@ -225,13 +276,19 @@ class Trainer:
def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
# We use the same batch_size as for eval.
sampler = get_tpu_sampler(test_dataset) if is_tpu_available() else None
if is_tpu_available():
sampler = SequentialDistributedSampler(
test_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
)
elif self.args.local_rank != -1:
sampler = SequentialDistributedSampler(test_dataset)
else:
sampler = SequentialSampler(test_dataset)
data_loader = DataLoader(
test_dataset,
sampler=sampler,
batch_size=self.args.eval_batch_size,
shuffle=False,
collate_fn=self.data_collator.collate_batch,
)
@@ -405,6 +462,9 @@ class Trainer:
epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_master()
)
for epoch in train_iterator:
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
train_dataloader.sampler.set_epoch(epoch)
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master())
for step, inputs in enumerate(epoch_iterator):
@@ -435,27 +495,25 @@ class Trainer:
self.global_step += 1
self.epoch = epoch + (step + 1) / len(epoch_iterator)
if self.is_local_master():
if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
self.global_step == 1 and self.args.logging_first_step
):
logs: Dict[str, float] = {}
logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
# maintaining backward compatibility.
# could use "scheduler.get_last_lr()[0]" instead for pytorch >= 1.4.0
logs["learning_rate"] = (
scheduler.get_last_lr()[0]
if version.parse(torch.__version__) >= version.parse("1.4")
else scheduler.get_lr()[0]
)
logging_loss = tr_loss
self._log(logs)
if self.args.evaluate_during_training:
self.evaluate()
if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
self.global_step == 1 and self.args.logging_first_step
):
logs: Dict[str, float] = {}
logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
# backward compatibility for pytorch schedulers
logs["learning_rate"] = (
scheduler.get_last_lr()[0]
if version.parse(torch.__version__) >= version.parse("1.4")
else scheduler.get_lr()[0]
)
logging_loss = tr_loss
self._log(logs)
if self.args.evaluate_during_training:
self.evaluate()
if self.is_world_master():
if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
# In all cases (even distributed/parallel), self.model is always a reference
# to the model we want to save.
@@ -548,7 +606,7 @@ class Trainer:
Saving best-practices: if you use default names for the model,
you can reload it using from_pretrained().
Will only save from the master process.
Will only save from the world_master process (unless running on TPU).
"""
if is_tpu_available():
@@ -667,12 +725,15 @@ class Trainer:
prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only
model = self.model
model.to(self.args.device)
# multi-gpu eval
if self.args.n_gpu > 1 and not isinstance(self.model, torch.nn.DataParallel):
model = torch.nn.DataParallel(self.model)
if self.args.n_gpu > 1:
model = torch.nn.DataParallel(model)
else:
model = self.model
model.to(self.args.device)
# Note: in torch.distributed mode, there's no point in wrapping the model
# inside a DistributedDataParallel as we'll be under `no_grad` anyway.
if is_tpu_available():
batch_size = dataloader._loader._loader.batch_size
@@ -682,8 +743,8 @@ class Trainer:
logger.info(" Num examples = %d", self.num_examples(dataloader))
logger.info(" Batch size = %d", batch_size)
eval_losses: List[float] = []
preds: np.ndarray = None
label_ids: np.ndarray = None
preds: torch.Tensor = None
label_ids: torch.Tensor = None
model.eval()
for inputs in tqdm(dataloader, desc=description):
@@ -702,19 +763,33 @@ class Trainer:
if not prediction_loss_only:
if preds is None:
preds = logits.detach().cpu().numpy()
preds = logits.detach()
else:
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
preds = torch.cat((preds, logits.detach()), dim=0)
if inputs.get("labels") is not None:
if label_ids is None:
label_ids = inputs["labels"].detach().cpu().numpy()
label_ids = inputs["labels"].detach()
else:
label_ids = np.append(label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)
label_ids = torch.cat((label_ids, inputs["labels"].detach()), dim=0)
if is_tpu_available() and preds is not None and label_ids is not None:
if self.args.local_rank != -1:
# In distributed mode, concatenate all results from all nodes:
if preds is not None:
preds = self.distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
if label_ids is not None:
label_ids = self.distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
elif is_tpu_available():
# tpu-comment: Get all predictions and labels from all worker shards of eval dataset
preds = xm.mesh_reduce("eval_preds", preds, np.concatenate)
label_ids = xm.mesh_reduce("eval_out_label_ids", label_ids, np.concatenate)
if preds is not None:
preds = xm.mesh_reduce("eval_preds", preds, torch.cat)
if label_ids is not None:
label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat)
# Finally, turn the aggregated tensors into numpy arrays.
if preds is not None:
preds = preds.cpu().numpy()
if label_ids is not None:
label_ids = label_ids.cpu().numpy()
if self.compute_metrics is not None and preds is not None and label_ids is not None:
metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
@@ -729,3 +804,15 @@ class Trainer:
metrics[f"eval_{key}"] = metrics.pop(key)
return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
def distributed_concat(self, tensor: torch.Tensor, num_total_examples: int) -> torch.Tensor:
assert self.args.local_rank != -1
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(output_tensors, tensor)
concat = torch.cat(output_tensors, dim=0)
# truncate the dummy elements added by SequentialDistributedSampler
output = concat[:num_total_examples]
return output
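A worked illustration of why the truncation restores the original order (mirroring the 7-example, 2-GPU case above, not part of the diff): all_gather returns the shards in rank order and the sequential sampler gives rank 0 the first block of indices, so the concatenation is [0, 1, 2, 3] followed by [4, 5, 6, 0], and slicing to num_total_examples drops the padded duplicate at the end.

    gathered = torch.cat([torch.tensor([0, 1, 2, 3]), torch.tensor([4, 5, 6, 0])], dim=0)
    assert gathered[:7].tolist() == [0, 1, 2, 3, 4, 5, 6]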
# This test is meant to be run in torch.distributed,
# on a machine with multiple GPUs, in the following way:
#
# python -m torch.distributed.launch --nproc_per_node 2 ./tests/test_trainer_distributed.py
#
# Replace 2 with the number of GPUs you have.
#
# You can also run it as a standalone file to test identical behavior in nn.DataParallel:
# python ./tests/test_trainer_distributed.py
# and in single-GPU mode:
# CUDA_VISIBLE_DEVICES=0 python ./tests/test_trainer_distributed.py
#
import logging
import sys
from typing import Dict
from transformers import EvalPrediction, HfArgumentParser, TrainingArguments, is_torch_available
logger = logging.getLogger(__name__)
if is_torch_available():
import torch
from torch import nn
from torch.utils.data.dataset import Dataset
from transformers import DataCollator, Trainer
class DummyDataset(Dataset):
def __init__(self, length: int = 101):
self.length = length
def __len__(self):
return self.length
def __getitem__(self, i) -> int:
return i
class DummyDataCollator(DataCollator):
def collate_batch(self, features):
return {"input_ids": torch.tensor(features), "labels": torch.tensor(features)}
class DummyModel(nn.Module):
def __init__(self):
super().__init__()
# Add some (unused) params otherwise DDP will complain.
self.fc = nn.Linear(120, 80)
def forward(self, input_ids, labels=None):
if labels is not None:
return torch.tensor(0.0, device=input_ids.device), input_ids
else:
return input_ids
if __name__ == "__main__":
parser = HfArgumentParser((TrainingArguments,))
training_args = parser.parse_args_into_dataclasses(sys.argv + ["--output_dir", "./examples"])[0]
logging.basicConfig(level=logging.INFO)
logger.warning(
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s",
training_args.local_rank,
training_args.device,
training_args.n_gpu,
training_args.local_rank != -1,
)
# Essentially, what we want to verify in the distributed case is
# that we get all samples back, in the right order.
# (this is crucial for prediction for instance)
for dataset_length in [101, 40, 7]:
dataset = DummyDataset(dataset_length)
def compute_metrics(p: EvalPrediction) -> Dict:
sequential = list(range(len(dataset)))
success = p.predictions.tolist() == sequential and p.label_ids.tolist() == sequential
return {"success": success}
trainer = Trainer(
model=DummyModel(),
args=training_args,
data_collator=DummyDataCollator(),
eval_dataset=dataset,
compute_metrics=compute_metrics,
)
metrics = trainer.evaluate()
logger.info(metrics)
if metrics["eval_success"] is not True:
logger.error(metrics)
exit(1)
p = trainer.predict(dataset)
logger.info(p.metrics)
if p.metrics["eval_success"] is not True:
logger.error(p.metrics)
exit(1)
logger.info("🔥 All distributed tests successful")