Unverified Commit 78387cc6 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[s2s] only save metrics.json from rank zero (#7331)

parent e53138a1
......@@ -8,6 +8,8 @@ import torch
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only
from utils import save_json
def count_trainable_parameters(model):
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
......@@ -72,8 +74,15 @@ class Seq2SeqLoggingCallback(pl.Callback):
@rank_zero_only
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
save_json(pl_module.metrics, pl_module.metrics_save_path)
return self._write_logs(trainer, pl_module, "test")
@rank_zero_only
def on_validation_end(self, trainer: pl.Trainer, pl_module):
save_json(pl_module.metrics, pl_module.metrics_save_path)
# Uncommenting this will save val generations
# return self._write_logs(trainer, pl_module, "valid")
def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=False):
"""Saves the best model by validation ROUGE2 score."""
......
......@@ -30,7 +30,6 @@ from utils import (
lmap,
pickle_save,
save_git_info,
save_json,
use_task_specific_params,
)
......@@ -189,7 +188,7 @@ class SummarizationModule(BaseTransformer):
losses.update(generative_metrics)
all_metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
all_metrics["step_count"] = self.step_count
self.save_metrics(all_metrics, prefix) # writes to self.metrics_save_path
self.metrics[prefix].append(all_metrics) # callback writes this to self.metrics_save_path
preds = flatten_list([x["preds"] for x in outputs])
return {
"log": all_metrics,
......@@ -198,10 +197,6 @@ class SummarizationModule(BaseTransformer):
f"{prefix}_{self.val_metric}": metric_tensor,
}
def save_metrics(self, latest_metrics, type_path) -> None:
self.metrics[type_path].append(latest_metrics)
save_json(self.metrics, self.metrics_save_path)
def calc_generative_metrics(self, preds, target) -> Dict:
return calculate_rouge(preds, target)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment