Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
d8ca57d2
Unverified
Commit
d8ca57d2
authored
Oct 16, 2020
by
Stas Bekman
Committed by
GitHub
Oct 16, 2020
Browse files
fix/hide warnings (#7837)
s
parent
c6e865ac
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
9 deletions
+18
-9
tests/test_trainer_callback.py
tests/test_trainer_callback.py
+18
-9
No files found.
tests/test_trainer_callback.py
View file @
d8ca57d2
...
@@ -21,7 +21,7 @@ if is_torch_available():
...
@@ -21,7 +21,7 @@ if is_torch_available():
from
.test_trainer
import
RegressionDataset
,
RegressionModelConfig
,
RegressionPreTrainedModel
from
.test_trainer
import
RegressionDataset
,
RegressionModelConfig
,
RegressionPreTrainedModel
class
TestTrainerCallback
(
TrainerCallback
):
class
My
TestTrainerCallback
(
TrainerCallback
):
"A callback that registers the events that goes through."
"A callback that registers the events that goes through."
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -134,8 +134,8 @@ class TrainerCallbackTest(unittest.TestCase):
...
@@ -134,8 +134,8 @@ class TrainerCallbackTest(unittest.TestCase):
self
.
check_callbacks_equality
(
trainer
.
callback_handler
.
callbacks
,
expected_callbacks
)
self
.
check_callbacks_equality
(
trainer
.
callback_handler
.
callbacks
,
expected_callbacks
)
# Callbacks passed at init are added to the default callbacks
# Callbacks passed at init are added to the default callbacks
trainer
=
self
.
get_trainer
(
callbacks
=
[
TestTrainerCallback
])
trainer
=
self
.
get_trainer
(
callbacks
=
[
My
TestTrainerCallback
])
expected_callbacks
.
append
(
TestTrainerCallback
)
expected_callbacks
.
append
(
My
TestTrainerCallback
)
self
.
check_callbacks_equality
(
trainer
.
callback_handler
.
callbacks
,
expected_callbacks
)
self
.
check_callbacks_equality
(
trainer
.
callback_handler
.
callbacks
,
expected_callbacks
)
# TrainingArguments.disable_tqdm controls if use ProgressCallback or PrinterCallback
# TrainingArguments.disable_tqdm controls if use ProgressCallback or PrinterCallback
...
@@ -179,35 +179,44 @@ class TrainerCallbackTest(unittest.TestCase):
...
@@ -179,35 +179,44 @@ class TrainerCallbackTest(unittest.TestCase):
self
.
check_callbacks_equality
(
trainer
.
callback_handler
.
callbacks
,
expected_callbacks
)
self
.
check_callbacks_equality
(
trainer
.
callback_handler
.
callbacks
,
expected_callbacks
)
def
test_event_flow
(
self
):
def
test_event_flow
(
self
):
trainer
=
self
.
get_trainer
(
callbacks
=
[
TestTrainerCallback
])
import
warnings
# XXX: for now ignore scatter_gather warnings in this test since it's not relevant to what's being tested
warnings
.
simplefilter
(
action
=
"ignore"
,
category
=
UserWarning
)
trainer
=
self
.
get_trainer
(
callbacks
=
[
MyTestTrainerCallback
])
trainer
.
train
()
trainer
.
train
()
events
=
trainer
.
callback_handler
.
callbacks
[
-
2
].
events
events
=
trainer
.
callback_handler
.
callbacks
[
-
2
].
events
self
.
assertEqual
(
events
,
self
.
get_expected_events
(
trainer
))
self
.
assertEqual
(
events
,
self
.
get_expected_events
(
trainer
))
# Independent log/save/eval
# Independent log/save/eval
trainer
=
self
.
get_trainer
(
callbacks
=
[
TestTrainerCallback
],
logging_steps
=
5
)
trainer
=
self
.
get_trainer
(
callbacks
=
[
My
TestTrainerCallback
],
logging_steps
=
5
)
trainer
.
train
()
trainer
.
train
()
events
=
trainer
.
callback_handler
.
callbacks
[
-
2
].
events
events
=
trainer
.
callback_handler
.
callbacks
[
-
2
].
events
self
.
assertEqual
(
events
,
self
.
get_expected_events
(
trainer
))
self
.
assertEqual
(
events
,
self
.
get_expected_events
(
trainer
))
trainer
=
self
.
get_trainer
(
callbacks
=
[
TestTrainerCallback
],
save_steps
=
5
)
trainer
=
self
.
get_trainer
(
callbacks
=
[
My
TestTrainerCallback
],
save_steps
=
5
)
trainer
.
train
()
trainer
.
train
()
events
=
trainer
.
callback_handler
.
callbacks
[
-
2
].
events
events
=
trainer
.
callback_handler
.
callbacks
[
-
2
].
events
self
.
assertEqual
(
events
,
self
.
get_expected_events
(
trainer
))
self
.
assertEqual
(
events
,
self
.
get_expected_events
(
trainer
))
trainer
=
self
.
get_trainer
(
callbacks
=
[
TestTrainerCallback
],
eval_steps
=
5
,
evaluation_strategy
=
"steps"
)
trainer
=
self
.
get_trainer
(
callbacks
=
[
My
TestTrainerCallback
],
eval_steps
=
5
,
evaluation_strategy
=
"steps"
)
trainer
.
train
()
trainer
.
train
()
events
=
trainer
.
callback_handler
.
callbacks
[
-
2
].
events
events
=
trainer
.
callback_handler
.
callbacks
[
-
2
].
events
self
.
assertEqual
(
events
,
self
.
get_expected_events
(
trainer
))
self
.
assertEqual
(
events
,
self
.
get_expected_events
(
trainer
))
trainer
=
self
.
get_trainer
(
callbacks
=
[
TestTrainerCallback
],
evaluation_strategy
=
"epoch"
)
trainer
=
self
.
get_trainer
(
callbacks
=
[
My
TestTrainerCallback
],
evaluation_strategy
=
"epoch"
)
trainer
.
train
()
trainer
.
train
()
events
=
trainer
.
callback_handler
.
callbacks
[
-
2
].
events
events
=
trainer
.
callback_handler
.
callbacks
[
-
2
].
events
self
.
assertEqual
(
events
,
self
.
get_expected_events
(
trainer
))
self
.
assertEqual
(
events
,
self
.
get_expected_events
(
trainer
))
# A bit of everything
# A bit of everything
trainer
=
self
.
get_trainer
(
trainer
=
self
.
get_trainer
(
callbacks
=
[
TestTrainerCallback
],
logging_steps
=
3
,
save_steps
=
10
,
eval_steps
=
5
,
evaluation_strategy
=
"steps"
callbacks
=
[
MyTestTrainerCallback
],
logging_steps
=
3
,
save_steps
=
10
,
eval_steps
=
5
,
evaluation_strategy
=
"steps"
,
)
)
trainer
.
train
()
trainer
.
train
()
events
=
trainer
.
callback_handler
.
callbacks
[
-
2
].
events
events
=
trainer
.
callback_handler
.
callbacks
[
-
2
].
events
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment