Unverified Commit 7bff0af0 authored by Harutaka Kawamura's avatar Harutaka Kawamura Committed by GitHub
Browse files

Fix a bug for `CallbackHandler.callback_list` (#8052)



* Fix callback_list

* Add test
Signed-off-by: default avatarharupy <17039389+harupy@users.noreply.github.com>

* Fix test
Signed-off-by: default avatarharupy <17039389+harupy@users.noreply.github.com>
parent 8e28c327
...@@ -325,7 +325,7 @@ class CallbackHandler(TrainerCallback): ...@@ -325,7 +325,7 @@ class CallbackHandler(TrainerCallback):
@property @property
def callback_list(self): def callback_list(self):
return "\n".join(self.callbacks) return "\n".join(cb.__class__.__name__ for cb in self.callbacks)
def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
return self.call_event("on_init_end", args, state, control) return self.call_event("on_init_end", args, state, control)
......
...@@ -221,3 +221,10 @@ class TrainerCallbackTest(unittest.TestCase): ...@@ -221,3 +221,10 @@ class TrainerCallbackTest(unittest.TestCase):
trainer.train() trainer.train()
events = trainer.callback_handler.callbacks[-2].events events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer)) self.assertEqual(events, self.get_expected_events(trainer))
# warning should be emitted for duplicated callbacks
with unittest.mock.patch("transformers.trainer_callback.logger.warn") as warn_mock:
trainer = self.get_trainer(
callbacks=[MyTestTrainerCallback, MyTestTrainerCallback],
)
assert str(MyTestTrainerCallback) in warn_mock.call_args[0][0]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment