"...composable_kernel_onnxruntime.git" did not exist on "cc8df39e780fadf259ae0b822fa8403fa214d4ba"
Unverified commit bdcc4b78, authored by Sylvain Gugger, committed by GitHub

Fix seq2seq example test (#7518)

* Fix seq2seq example test

* Fix bad copy-paste

* Also save the state
parent 29baa8fa
@@ -276,6 +276,7 @@ def main():
     # For convenience, we also re-save the tokenizer to the same directory,
     # so that you can share your model easily on huggingface.co/models =)
     if trainer.is_world_process_zero():
+        trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
         tokenizer.save_pretrained(training_args.output_dir)

     # Evaluation
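
For context on what this added line produces: trainer.state is a TrainerState, and the JSON it writes can be read back with TrainerState.load_from_json, which is what the updated tests below rely on. A minimal sketch, assuming a finished run; the output directory is a placeholder, and the import path follows the test diff below (it may differ in other transformers versions):

    import os

    from transformers.trainer_utils import TrainerState  # import path as in the test diff below

    output_dir = "output"  # placeholder for training_args.output_dir
    state = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json"))
    # log_history holds the dicts logged during training and evaluation
    eval_logs = [log for log in state.log_history if "eval_loss" in log]
    print(state.global_step, len(eval_logs))
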
@@ -4,11 +4,10 @@ import tempfile
 from unittest.mock import patch

 from transformers.testing_utils import slow
-from transformers.trainer_utils import set_seed
+from transformers.trainer_utils import TrainerState, set_seed

 from .finetune_trainer import main
 from .test_seq2seq_examples import MBART_TINY
-from .utils import load_json


 set_seed(42)
@@ -17,7 +16,7 @@ MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"

 def test_finetune_trainer():
     output_dir = run_trainer(1, "12", MBART_TINY, 1)
-    logs = load_json(os.path.join(output_dir, "log_history.json"))
+    logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
     eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
     first_step_stats = eval_metrics[0]
     assert "eval_bleu" in first_step_stats
@@ -30,7 +29,7 @@ def test_finetune_trainer_slow():
     output_dir = run_trainer(eval_steps=2, max_len="32", model_name=MARIAN_MODEL, num_train_epochs=3)

     # Check metrics
-    logs = load_json(os.path.join(output_dir, "log_history.json"))
+    logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
     eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
     first_step_stats = eval_metrics[0]
     last_step_stats = eval_metrics[-1]
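
For context on the filtering in these tests: log_history mixes training logs and evaluation logs, so the tests keep only the entries that carry an eval_loss key. An illustrative sketch of that shape; everything beyond the eval_loss/eval_bleu keys used by the tests is an assumption, and the values are placeholders rather than real output:

    # Illustrative only: the shape the tests filter on; values are placeholders.
    log_history = [
        {"loss": 10.0, "learning_rate": 3e-05, "step": 1},  # training log, no eval_loss
        {"eval_loss": 9.5, "eval_bleu": 0.1, "step": 1},     # evaluation log
        {"eval_loss": 9.0, "eval_bleu": 0.2, "step": 2},     # later evaluation log
    ]
    eval_metrics = [log for log in log_history if "eval_loss" in log.keys()]
    assert "eval_bleu" in eval_metrics[0]
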
@@ -601,6 +601,7 @@ class Trainer:
         output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
         self.save_model(output_dir)

         if self.is_world_master():
+            self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
             torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
             torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
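
On the Trainer side, every checkpoint directory now carries trainer_state.json next to the optimizer and scheduler files. A minimal sketch of reading those pieces back; the checkpoint path is a placeholder, and the torch.load calls simply mirror the torch.save calls above rather than any dedicated resume API:

    import os

    import torch

    from transformers.trainer_utils import TrainerState  # import path as used in this commit

    ckpt_dir = "output/checkpoint-500"  # placeholder for f"{PREFIX_CHECKPOINT_DIR}-{global_step}"
    state = TrainerState.load_from_json(os.path.join(ckpt_dir, "trainer_state.json"))
    optimizer_state = torch.load(os.path.join(ckpt_dir, "optimizer.pt"), map_location="cpu")
    scheduler_state = torch.load(os.path.join(ckpt_dir, "scheduler.pt"), map_location="cpu")
    # e.g. optimizer.load_state_dict(optimizer_state); lr_scheduler.load_state_dict(scheduler_state)
    print(f"checkpoint written at global step {state.global_step}")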