Unverified Commit 080e14b2 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Modify `warnings` in a `with` block to avoid flaky tests (#31893)



* fix

* [test_all] check before merge

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent ec03d97b
...@@ -218,52 +218,53 @@ class TrainerCallbackTest(unittest.TestCase): ...@@ -218,52 +218,53 @@ class TrainerCallbackTest(unittest.TestCase):
import warnings import warnings
# XXX: for now ignore scatter_gather warnings in this test since it's not relevant to what's being tested # XXX: for now ignore scatter_gather warnings in this test since it's not relevant to what's being tested
warnings.simplefilter(action="ignore", category=UserWarning) with warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=UserWarning)
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback])
trainer.train() trainer = self.get_trainer(callbacks=[MyTestTrainerCallback])
events = trainer.callback_handler.callbacks[-2].events trainer.train()
self.assertEqual(events, self.get_expected_events(trainer)) events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))
# Independent log/save/eval
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], logging_steps=5) # Independent log/save/eval
trainer.train() trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], logging_steps=5)
events = trainer.callback_handler.callbacks[-2].events trainer.train()
self.assertEqual(events, self.get_expected_events(trainer)) events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], save_steps=5)
trainer.train() trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], save_steps=5)
events = trainer.callback_handler.callbacks[-2].events trainer.train()
self.assertEqual(events, self.get_expected_events(trainer)) events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], eval_steps=5, eval_strategy="steps")
trainer.train() trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], eval_steps=5, eval_strategy="steps")
events = trainer.callback_handler.callbacks[-2].events trainer.train()
self.assertEqual(events, self.get_expected_events(trainer)) events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], eval_strategy="epoch")
trainer.train() trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], eval_strategy="epoch")
events = trainer.callback_handler.callbacks[-2].events trainer.train()
self.assertEqual(events, self.get_expected_events(trainer)) events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))
# A bit of everything
trainer = self.get_trainer( # A bit of everything
callbacks=[MyTestTrainerCallback],
logging_steps=3,
save_steps=10,
eval_steps=5,
eval_strategy="steps",
)
trainer.train()
events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))
# warning should be emitted for duplicated callbacks
with patch("transformers.trainer_callback.logger.warning") as warn_mock:
trainer = self.get_trainer( trainer = self.get_trainer(
callbacks=[MyTestTrainerCallback, MyTestTrainerCallback], callbacks=[MyTestTrainerCallback],
logging_steps=3,
save_steps=10,
eval_steps=5,
eval_strategy="steps",
) )
assert str(MyTestTrainerCallback) in warn_mock.call_args[0][0] trainer.train()
events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))
# warning should be emitted for duplicated callbacks
with patch("transformers.trainer_callback.logger.warning") as warn_mock:
trainer = self.get_trainer(
callbacks=[MyTestTrainerCallback, MyTestTrainerCallback],
)
assert str(MyTestTrainerCallback) in warn_mock.call_args[0][0]
def test_stateful_callbacks(self): def test_stateful_callbacks(self):
# Use something with non-defaults # Use something with non-defaults
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment