"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "0ecfd17f49e81cade5ddd7321480a5a3492dd36d"
Unverified commit 5e7fe8b5, authored by Julien Chaumond, committed by GitHub

Distributed eval: SequentialDistributedSampler + gather all results (#4243)

* Distributed eval: SequentialDistributedSampler + gather all results

* For consistency only write to disk from world_master

Close https://github.com/huggingface/transformers/issues/4272

* Working distributed eval

* Hook into scripts

* Fix #3721 again

* TPU.mesh_reduce: stay in tensor space

Thanks @jysohn23

* Just a small comment

* whitespace

* torch.hub: pip install packaging

* Add test scenarios
parent 4c068936
@@ -21,7 +21,7 @@ jobs:
       - name: Install dependencies
         run: |
           pip install torch
-          pip install numpy tokenizers filelock requests tqdm regex sentencepiece sacremoses
+          pip install numpy tokenizers filelock requests tqdm regex sentencepiece sacremoses packaging
       - name: Torch hub list
         run: |
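The `packaging` addition to the torch.hub CI job is presumably needed so that loading the library through torch.hub can import modules that use `packaging.version`; the `Trainer`'s torch-version check (visible later in this diff) is one such use. A minimal sketch of that kind of check in isolation; the helper name `get_current_lr` is illustrative and not part of the library:

```python
# Sketch: the pattern that makes `packaging` a hard requirement.
# packaging.version gives a real version comparison instead of string comparison.
from packaging import version

import torch


def get_current_lr(scheduler) -> float:
    """Illustrative helper: pick the right LR accessor for the installed torch."""
    if version.parse(torch.__version__) >= version.parse("1.4"):
        return scheduler.get_last_lr()[0]
    # Older schedulers only expose get_lr().
    return scheduler.get_lr()[0]
```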
@@ -251,7 +251,7 @@ def main():
     # Evaluation
     results = {}
-    if training_args.do_eval and training_args.local_rank in [-1, 0]:
+    if training_args.do_eval:
         logger.info("*** Evaluate ***")
 
         eval_output = trainer.evaluate()
@@ -260,11 +260,12 @@ def main():
         result = {"perplexity": perplexity}
 
         output_eval_file = os.path.join(training_args.output_dir, "eval_results_lm.txt")
-        with open(output_eval_file, "w") as writer:
-            logger.info("***** Eval results *****")
-            for key in sorted(result.keys()):
-                logger.info("  %s = %s", key, str(result[key]))
-                writer.write("%s = %s\n" % (key, str(result[key])))
+        if trainer.is_world_master():
+            with open(output_eval_file, "w") as writer:
+                logger.info("***** Eval results *****")
+                for key in sorted(result.keys()):
+                    logger.info("  %s = %s", key, str(result[key]))
+                    writer.write("%s = %s\n" % (key, str(result[key])))
 
         results.update(result)
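Note why the `local_rank in [-1, 0]` guard had to go: `trainer.evaluate()` now performs collective `all_gather` calls, so every process must enter it or the ranks that did would block forever; only the filesystem write stays gated on the world master. A sketch of the pattern the example scripts now follow, as a standalone function (the function name is illustrative; `trainer` is assumed to be a `transformers.Trainer`):

```python
import os


def evaluate_and_write(trainer, training_args, logger) -> dict:
    # Every rank must call evaluate(): the all_gather inside it is a collective
    # operation, so skipping it on non-zero ranks would hang the world master.
    result = trainer.evaluate()

    # Only one process touches the filesystem.
    if trainer.is_world_master():
        output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            for key in sorted(result):
                logger.info("  %s = %s", key, result[key])
                writer.write("%s = %s\n" % (key, result[key]))
    return result
```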
@@ -202,19 +202,20 @@ def main():
     # Evaluation
     results = {}
-    if training_args.do_eval and training_args.local_rank in [-1, 0]:
+    if training_args.do_eval:
         logger.info("*** Evaluate ***")
 
         result = trainer.evaluate()
 
         output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
-        with open(output_eval_file, "w") as writer:
-            logger.info("***** Eval results *****")
-            for key, value in result.items():
-                logger.info("  %s = %s", key, value)
-                writer.write("%s = %s\n" % (key, value))
+        if trainer.is_world_master():
+            with open(output_eval_file, "w") as writer:
+                logger.info("***** Eval results *****")
+                for key, value in result.items():
+                    logger.info("  %s = %s", key, value)
+                    writer.write("%s = %s\n" % (key, value))
 
         results.update(result)
 
     return results
@@ -166,7 +166,7 @@ def main():
     # Evaluation
     results = {}
-    if training_args.do_eval and training_args.local_rank in [-1, 0]:
+    if training_args.do_eval:
         logger.info("*** Evaluate ***")
 
         # Loop to handle MNLI double evaluation (matched, mis-matched)
@@ -181,11 +181,12 @@ def main():
            output_eval_file = os.path.join(
                training_args.output_dir, f"eval_results_{eval_dataset.args.task_name}.txt"
            )
-           with open(output_eval_file, "w") as writer:
-               logger.info("***** Eval results {} *****".format(eval_dataset.args.task_name))
-               for key, value in result.items():
-                   logger.info("  %s = %s", key, value)
-                   writer.write("%s = %s\n" % (key, value))
+           if trainer.is_world_master():
+               with open(output_eval_file, "w") as writer:
+                   logger.info("***** Eval results {} *****".format(eval_dataset.args.task_name))
+                   for key, value in result.items():
+                       logger.info("  %s = %s", key, value)
+                       writer.write("%s = %s\n" % (key, value))
 
            results.update(result)
@@ -235,22 +235,23 @@ def main():
     # Evaluation
     results = {}
-    if training_args.do_eval and training_args.local_rank in [-1, 0]:
+    if training_args.do_eval:
         logger.info("*** Evaluate ***")
 
         result = trainer.evaluate()
 
         output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
-        with open(output_eval_file, "w") as writer:
-            logger.info("***** Eval results *****")
-            for key, value in result.items():
-                logger.info("  %s = %s", key, value)
-                writer.write("%s = %s\n" % (key, value))
+        if trainer.is_world_master():
+            with open(output_eval_file, "w") as writer:
+                logger.info("***** Eval results *****")
+                for key, value in result.items():
+                    logger.info("  %s = %s", key, value)
+                    writer.write("%s = %s\n" % (key, value))
 
         results.update(result)
 
     # Predict
-    if training_args.do_predict and training_args.local_rank in [-1, 0]:
+    if training_args.do_predict:
         test_dataset = NerDataset(
             data_dir=data_args.data_dir,
             tokenizer=tokenizer,
@@ -265,26 +266,30 @@ def main():
         preds_list, _ = align_predictions(predictions, label_ids)
 
         output_test_results_file = os.path.join(training_args.output_dir, "test_results.txt")
-        with open(output_test_results_file, "w") as writer:
-            for key, value in metrics.items():
-                logger.info("  %s = %s", key, value)
-                writer.write("%s = %s\n" % (key, value))
+        if trainer.is_world_master():
+            with open(output_test_results_file, "w") as writer:
+                for key, value in metrics.items():
+                    logger.info("  %s = %s", key, value)
+                    writer.write("%s = %s\n" % (key, value))
 
         # Save predictions
         output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt")
-        with open(output_test_predictions_file, "w") as writer:
-            with open(os.path.join(data_args.data_dir, "test.txt"), "r") as f:
-                example_id = 0
-                for line in f:
-                    if line.startswith("-DOCSTART-") or line == "" or line == "\n":
-                        writer.write(line)
-                        if not preds_list[example_id]:
-                            example_id += 1
-                    elif preds_list[example_id]:
-                        output_line = line.split()[0] + " " + preds_list[example_id].pop(0) + "\n"
-                        writer.write(output_line)
-                    else:
-                        logger.warning("Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0])
+        if trainer.is_world_master():
+            with open(output_test_predictions_file, "w") as writer:
+                with open(os.path.join(data_args.data_dir, "test.txt"), "r") as f:
+                    example_id = 0
+                    for line in f:
+                        if line.startswith("-DOCSTART-") or line == "" or line == "\n":
+                            writer.write(line)
+                            if not preds_list[example_id]:
+                                example_id += 1
+                        elif preds_list[example_id]:
+                            output_line = line.split()[0] + " " + preds_list[example_id].pop(0) + "\n"
+                            writer.write(output_line)
+                        else:
+                            logger.warning(
+                                "Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0]
+                            )
 
     return results
 import json
 import logging
+import math
 import os
 import random
 import re
@@ -15,7 +16,7 @@ from torch import nn
 from torch.utils.data.dataloader import DataLoader
 from torch.utils.data.dataset import Dataset
 from torch.utils.data.distributed import DistributedSampler
-from torch.utils.data.sampler import RandomSampler
+from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
 from tqdm.auto import tqdm, trange
 
 from .data.data_collator import DataCollator, DefaultDataCollator
@@ -90,7 +91,7 @@ def set_seed(seed: int):
 @contextmanager
 def torch_distributed_zero_first(local_rank: int):
     """
-    Decorator to make all processes in distributed training wait for the first one (locally) to do something.
+    Decorator to make all processes in distributed training wait for each local_master to do something.
     """
     if local_rank not in [-1, 0]:
         torch.distributed.barrier()
@@ -99,6 +100,50 @@ def torch_distributed_zero_first(local_rank: int):
     torch.distributed.barrier()
+
+
+class SequentialDistributedSampler(Sampler):
+    """
+    Distributed Sampler that subsamples indices sequentially,
+    making it easier to collate all results at the end.
+
+    Even though we only use this sampler for eval and predict (no training),
+    which means that the model params won't have to be synced (i.e. will not hang
+    for synchronization even if varied number of forward passes), we still add extra
+    samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
+    to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
+    """
+
+    def __init__(self, dataset, num_replicas=None, rank=None):
+        if num_replicas is None:
+            if not torch.distributed.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = torch.distributed.get_world_size()
+        if rank is None:
+            if not torch.distributed.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = torch.distributed.get_rank()
+        self.dataset = dataset
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+        self.total_size = self.num_samples * self.num_replicas
+
+    def __iter__(self):
+        indices = list(range(len(self.dataset)))
+
+        # add extra samples to make it evenly divisible
+        indices += indices[: (self.total_size - len(indices))]
+        assert len(indices) == self.total_size
+
+        # subsample
+        indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
+        assert len(indices) == self.num_samples
+
+        return iter(indices)
+
+    def __len__(self):
+        return self.num_samples
+
+
 def get_tpu_sampler(dataset: Dataset):
     if xm.xrt_world_size() <= 1:
         return RandomSampler(dataset)
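To see what `SequentialDistributedSampler` does to the index space, here is a small single-process sketch that inlines the same padding and slicing arithmetic for a dataset of 7 examples split across 2 replicas (the class normally reads `num_replicas` and `rank` from `torch.distributed`; here they are fixed explicitly):

```python
import math

# Same arithmetic as SequentialDistributedSampler.__iter__, shown inline.
dataset_length, num_replicas = 7, 2
num_samples = int(math.ceil(dataset_length / num_replicas))  # 4 per replica
total_size = num_samples * num_replicas                      # 8 after padding

indices = list(range(dataset_length))
indices += indices[: total_size - len(indices)]              # wrap-around padding -> [0..6, 0]

for rank in range(num_replicas):
    shard = indices[rank * num_samples : (rank + 1) * num_samples]
    print(rank, shard)
# rank 0 -> [0, 1, 2, 3]
# rank 1 -> [4, 5, 6, 0]  (the trailing 0 is the dummy later dropped by distributed_concat)
```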
@@ -156,7 +201,7 @@ class Trainer:
         self.optimizers = optimizers
         if tb_writer is not None:
             self.tb_writer = tb_writer
-        elif is_tensorboard_available() and self.args.local_rank in [-1, 0]:
+        elif is_tensorboard_available() and self.is_world_master():
             self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
         if not is_tensorboard_available():
             logger.warning(
@@ -171,7 +216,7 @@
             )
         set_seed(self.args.seed)
         # Create output directory if needed
-        if self.is_local_master():
+        if self.is_world_master():
             os.makedirs(self.args.output_dir, exist_ok=True)
         if is_tpu_available():
             # Set an xla_device flag on the model's config.
@@ -208,13 +253,19 @@
         eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
 
-        sampler = get_tpu_sampler(eval_dataset) if is_tpu_available() else None
+        if is_tpu_available():
+            sampler = SequentialDistributedSampler(
+                eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
+            )
+        elif self.args.local_rank != -1:
+            sampler = SequentialDistributedSampler(eval_dataset)
+        else:
+            sampler = SequentialSampler(eval_dataset)
 
         data_loader = DataLoader(
             eval_dataset,
             sampler=sampler,
             batch_size=self.args.eval_batch_size,
-            shuffle=False,
             collate_fn=self.data_collator.collate_batch,
         )
@@ -225,13 +276,19 @@
     def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
         # We use the same batch_size as for eval.
-        sampler = get_tpu_sampler(test_dataset) if is_tpu_available() else None
+        if is_tpu_available():
+            sampler = SequentialDistributedSampler(
+                test_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
+            )
+        elif self.args.local_rank != -1:
+            sampler = SequentialDistributedSampler(test_dataset)
+        else:
+            sampler = SequentialSampler(test_dataset)
 
         data_loader = DataLoader(
             test_dataset,
             sampler=sampler,
             batch_size=self.args.eval_batch_size,
-            shuffle=False,
             collate_fn=self.data_collator.collate_batch,
         )
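With a sampler always supplied, the explicit `shuffle=False` argument becomes redundant (the sampler alone now determines iteration order), which is why it is dropped from both dataloaders. A rough sketch of the same selection logic outside the `Trainer`; `my_dataset`, `my_collate_fn`, and `local_rank` are stand-ins for the `Trainer`'s attributes, and `SequentialDistributedSampler` refers to the class added above:

```python
from torch.utils.data import DataLoader, SequentialSampler

# Import path as of this commit; the class is defined in src/transformers/trainer.py.
from transformers.trainer import SequentialDistributedSampler


def build_eval_loader(my_dataset, my_collate_fn, local_rank: int, batch_size: int) -> DataLoader:
    if local_rank != -1:
        # Each distributed process reads a contiguous, deterministic shard of the eval set
        # (requires torch.distributed to be initialized).
        sampler = SequentialDistributedSampler(my_dataset)
    else:
        sampler = SequentialSampler(my_dataset)
    # No shuffle flag: the sampler alone fixes the iteration order.
    return DataLoader(my_dataset, sampler=sampler, batch_size=batch_size, collate_fn=my_collate_fn)
```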
@@ -405,6 +462,9 @@
             epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_master()
         )
         for epoch in train_iterator:
+            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
+                train_dataloader.sampler.set_epoch(epoch)
+
             epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master())
 
             for step, inputs in enumerate(epoch_iterator):
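The `set_epoch` call matters because `DistributedSampler` derives its shuffle from the epoch number; without it, every epoch replays the epoch-0 permutation on every worker. A standalone illustration (single process, `num_replicas=1`, so no process group is required):

```python
from torch.utils.data.distributed import DistributedSampler

dataset = list(range(8))  # any sized container works; only len() is used
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)

sampler.set_epoch(0)
first = list(iter(sampler))
sampler.set_epoch(1)
second = list(iter(sampler))

print(first)   # a permutation of 0..7
print(second)  # a different permutation (the epoch feeds the shuffling seed)
```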
@@ -435,27 +495,25 @@
                 self.global_step += 1
                 self.epoch = epoch + (step + 1) / len(epoch_iterator)
 
-                if self.is_local_master():
-                    if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
-                        self.global_step == 1 and self.args.logging_first_step
-                    ):
-                        logs: Dict[str, float] = {}
-                        logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
-                        # maintaining backward compatibility.
-                        # could use "scheduler.get_last_lr()[0]" instead for pytorch >= 1.4.0
-                        logs["learning_rate"] = (
-                            scheduler.get_last_lr()[0]
-                            if version.parse(torch.__version__) >= version.parse("1.4")
-                            else scheduler.get_lr()[0]
-                        )
-                        logging_loss = tr_loss
-
-                        self._log(logs)
-
-                        if self.args.evaluate_during_training:
-                            self.evaluate()
-
+                if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
+                    self.global_step == 1 and self.args.logging_first_step
+                ):
+                    logs: Dict[str, float] = {}
+                    logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
+                    # backward compatibility for pytorch schedulers
+                    logs["learning_rate"] = (
+                        scheduler.get_last_lr()[0]
+                        if version.parse(torch.__version__) >= version.parse("1.4")
+                        else scheduler.get_lr()[0]
+                    )
+                    logging_loss = tr_loss
+
+                    self._log(logs)
+
+                    if self.args.evaluate_during_training:
+                        self.evaluate()
+
+                if self.is_world_master():
                     if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                         # In all cases (even distributed/parallel), self.model is always a reference
                         # to the model we want to save.
@@ -548,7 +606,7 @@
         Saving best-practices: if you use default names for the model,
         you can reload it using from_pretrained().
 
-        Will only save from the master process.
+        Will only save from the world_master process (unless in TPUs).
         """
 
         if is_tpu_available():
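The recurring switch from `is_local_master()` to `is_world_master()` is the heart of the "only write to disk from world_master" change: on a multi-node setup there is one local master per machine but a single world master overall, and only the latter should own files such as checkpoints, TensorBoard logs, and eval results. A hedged sketch of the distinction in plain `torch.distributed` terms (the real `Trainer` methods also special-case TPU ordinals):

```python
import torch.distributed as dist


def is_local_master(local_rank: int) -> bool:
    # One per machine: local process index 0, or single-process mode (local_rank == -1).
    return local_rank in [-1, 0]


def is_world_master() -> bool:
    # Exactly one across all machines: global rank 0 (sketch; TPU handling omitted).
    return not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0
```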
@@ -667,12 +725,15 @@
         prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only
 
+        model = self.model
+        model.to(self.args.device)
         # multi-gpu eval
-        if self.args.n_gpu > 1 and not isinstance(self.model, torch.nn.DataParallel):
-            model = torch.nn.DataParallel(self.model)
+        if self.args.n_gpu > 1:
+            model = torch.nn.DataParallel(model)
         else:
             model = self.model
-        model.to(self.args.device)
+        # Note: in torch.distributed mode, there's no point in wrapping the model
+        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
 
         if is_tpu_available():
             batch_size = dataloader._loader._loader.batch_size
@@ -682,8 +743,8 @@
         logger.info("  Num examples = %d", self.num_examples(dataloader))
         logger.info("  Batch size = %d", batch_size)
         eval_losses: List[float] = []
-        preds: np.ndarray = None
-        label_ids: np.ndarray = None
+        preds: torch.Tensor = None
+        label_ids: torch.Tensor = None
         model.eval()
 
         for inputs in tqdm(dataloader, desc=description):
@@ -702,19 +763,33 @@
             if not prediction_loss_only:
                 if preds is None:
-                    preds = logits.detach().cpu().numpy()
+                    preds = logits.detach()
                 else:
-                    preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
+                    preds = torch.cat((preds, logits.detach()), dim=0)
                 if inputs.get("labels") is not None:
                     if label_ids is None:
-                        label_ids = inputs["labels"].detach().cpu().numpy()
+                        label_ids = inputs["labels"].detach()
                     else:
-                        label_ids = np.append(label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)
+                        label_ids = torch.cat((label_ids, inputs["labels"].detach()), dim=0)
 
-        if is_tpu_available() and preds is not None and label_ids is not None:
+        if self.args.local_rank != -1:
+            # In distributed mode, concatenate all results from all nodes:
+            if preds is not None:
+                preds = self.distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
+            if label_ids is not None:
+                label_ids = self.distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
+        elif is_tpu_available():
             # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
-            preds = xm.mesh_reduce("eval_preds", preds, np.concatenate)
-            label_ids = xm.mesh_reduce("eval_out_label_ids", label_ids, np.concatenate)
+            if preds is not None:
+                preds = xm.mesh_reduce("eval_preds", preds, torch.cat)
+            if label_ids is not None:
+                label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat)
+
+        # Finally, turn the aggregated tensors into numpy arrays.
+        if preds is not None:
+            preds = preds.cpu().numpy()
+        if label_ids is not None:
+            label_ids = label_ids.cpu().numpy()
 
         if self.compute_metrics is not None and preds is not None and label_ids is not None:
             metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
@@ -729,3 +804,15 @@
                 metrics[f"eval_{key}"] = metrics.pop(key)
 
         return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
+
+    def distributed_concat(self, tensor: torch.Tensor, num_total_examples: int) -> torch.Tensor:
+        assert self.args.local_rank != -1
+
+        output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
+        torch.distributed.all_gather(output_tensors, tensor)
+
+        concat = torch.cat(output_tensors, dim=0)
+
+        # truncate the dummy elements added by SequentialDistributedSampler
+        output = concat[:num_total_examples]
+        return output
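`distributed_concat` relies on every rank contributing a tensor of identical shape (guaranteed by the sampler's padding) and then dropping the padded tail. The same mechanics in a self-contained two-process sketch, runnable with `python -m torch.distributed.launch --nproc_per_node 2 gather_demo.py`; the file name and the `gloo` backend are illustrative choices:

```python
import torch
import torch.distributed as dist

# Illustrative sketch of the gather-then-truncate trick used by distributed_concat.
dist.init_process_group(backend="gloo")  # gloo works on CPU-only machines
rank, world_size = dist.get_rank(), dist.get_world_size()

num_total_examples = 7                           # dataset size, not divisible by world_size
per_rank = -(-num_total_examples // world_size)  # ceil division, matches the sampler

# Each rank holds its (padded) shard; here the shard is just its own index range.
start = rank * per_rank
shard = torch.arange(start, start + per_rank) % num_total_examples

gathered = [torch.zeros_like(shard) for _ in range(world_size)]
dist.all_gather(gathered, shard)
preds = torch.cat(gathered, dim=0)[:num_total_examples]  # drop the dummy padding

if rank == 0:
    print(preds.tolist())  # [0, 1, 2, 3, 4, 5, 6] -- original order restored
```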
# This test is meant to be run in torch.distributed,
# on a machine with multiple GPUs, in the following way:
#
# python -m torch.distributed.launch --nproc_per_node 2 ./tests/test_trainer_distributed.py
#
# Replace 2 with the number of GPUs you have.
#
# You can also run it as a standalone file to test identical behavior in nn.DataParallel:
# python ./tests/test_trainer_distributed.py
# and in single-GPU mode:
# CUDA_VISIBLE_DEVICES=0 python ./tests/test_trainer_distributed.py
#
import logging
import sys
from typing import Dict

from transformers import EvalPrediction, HfArgumentParser, TrainingArguments, is_torch_available


logger = logging.getLogger(__name__)


if is_torch_available():
    import torch
    from torch import nn
    from torch.utils.data.dataset import Dataset

    from transformers import DataCollator, Trainer

    class DummyDataset(Dataset):
        def __init__(self, length: int = 101):
            self.length = length

        def __len__(self):
            return self.length

        def __getitem__(self, i) -> int:
            return i

    class DummyDataCollator(DataCollator):
        def collate_batch(self, features):
            return {"input_ids": torch.tensor(features), "labels": torch.tensor(features)}

    class DummyModel(nn.Module):
        def __init__(self):
            super().__init__()
            # Add some (unused) params otherwise DDP will complain.
            self.fc = nn.Linear(120, 80)

        def forward(self, input_ids, labels=None):
            if labels is not None:
                return torch.tensor(0.0, device=input_ids.device), input_ids
            else:
                return input_ids


if __name__ == "__main__":
    parser = HfArgumentParser((TrainingArguments,))
    training_args = parser.parse_args_into_dataclasses(sys.argv + ["--output_dir", "./examples"])[0]

    logging.basicConfig(level=logging.INFO)

    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        training_args.local_rank != -1,
    )

    # Essentially, what we want to verify in the distributed case is
    # that we get all samples back, in the right order.
    # (this is crucial for prediction for instance)
    for dataset_length in [101, 40, 7]:
        dataset = DummyDataset(dataset_length)

        def compute_metrics(p: EvalPrediction) -> Dict:
            sequential = list(range(len(dataset)))
            success = p.predictions.tolist() == sequential and p.label_ids.tolist() == sequential
            return {"success": success}

        trainer = Trainer(
            model=DummyModel(),
            args=training_args,
            data_collator=DummyDataCollator(),
            eval_dataset=dataset,
            compute_metrics=compute_metrics,
        )
        metrics = trainer.evaluate()
        logger.info(metrics)
        if metrics["eval_success"] is not True:
            logger.error(metrics)
            exit(1)

        p = trainer.predict(dataset)
        logger.info(p.metrics)
        if p.metrics["eval_success"] is not True:
            logger.error(p.metrics)
            exit(1)

    logger.info("🔥 All distributed tests successful")