Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
2ce3ddab
Unverified
Commit
2ce3ddab
authored
Oct 15, 2020
by
Sylvain Gugger
Committed by
GitHub
Oct 15, 2020
Browse files
Small fixes to NotebookProgressCallback (#7813)
parent
6f45dd2f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
5 deletions
+24
-5
src/transformers/file_utils.py
src/transformers/file_utils.py
+1
-1
src/transformers/utils/notebook.py
src/transformers/utils/notebook.py
+23
-4
No files found.
src/transformers/file_utils.py
View file @
2ce3ddab
...
@@ -153,7 +153,7 @@ try:
...
@@ -153,7 +153,7 @@ try:
import
IPython
# noqa: F401
import
IPython
# noqa: F401
_in_notebook
=
True
_in_notebook
=
True
except
:
# noqa: E722
except
(
AttributeError
,
ImportError
,
KeyError
):
_in_notebook
=
False
_in_notebook
=
False
...
...
src/transformers/utils/notebook.py
View file @
2ce3ddab
...
@@ -19,6 +19,7 @@ from typing import Optional
...
@@ -19,6 +19,7 @@ from typing import Optional
import
IPython.display
as
disp
import
IPython.display
as
disp
from
..trainer_callback
import
TrainerCallback
from
..trainer_callback
import
TrainerCallback
from
..trainer_utils
import
EvaluationStrategy
def
format_time
(
t
):
def
format_time
(
t
):
...
@@ -146,7 +147,7 @@ class NotebookProgressBar:
...
@@ -146,7 +147,7 @@ class NotebookProgressBar:
self
.
first_calls
=
self
.
warmup
self
.
first_calls
=
self
.
warmup
self
.
wait_for
=
1
self
.
wait_for
=
1
self
.
update_bar
(
value
)
self
.
update_bar
(
value
)
elif
value
<=
self
.
last_value
:
elif
value
<=
self
.
last_value
and
not
force_update
:
return
return
elif
force_update
or
self
.
first_calls
>
0
or
value
>=
min
(
self
.
last_value
+
self
.
wait_for
,
self
.
total
):
elif
force_update
or
self
.
first_calls
>
0
or
value
>=
min
(
self
.
last_value
+
self
.
wait_for
,
self
.
total
):
if
self
.
first_calls
>
0
:
if
self
.
first_calls
>
0
:
...
@@ -272,17 +273,25 @@ class NotebookProgressCallback(TrainerCallback):
...
@@ -272,17 +273,25 @@ class NotebookProgressCallback(TrainerCallback):
def
__init__
(
self
):
def
__init__
(
self
):
self
.
training_tracker
=
None
self
.
training_tracker
=
None
self
.
prediction_bar
=
None
self
.
prediction_bar
=
None
self
.
_force_next_update
=
False
def
on_train_begin
(
self
,
args
,
state
,
control
,
**
kwargs
):
def
on_train_begin
(
self
,
args
,
state
,
control
,
**
kwargs
):
self
.
first_column
=
"Epoch"
if
args
.
max_steps
<=
0
else
"Step"
self
.
first_column
=
"Epoch"
if
args
.
evaluation_strategy
==
EvaluationStrategy
.
EPOCH
else
"Step"
self
.
training_loss
=
0
self
.
training_loss
=
0
self
.
last_log
=
0
self
.
last_log
=
0
column_names
=
[
self
.
first_column
]
+
[
"Training Loss"
,
"Validation Loss"
]
column_names
=
[
self
.
first_column
]
+
[
"Training Loss"
]
if
args
.
evaluation_strategy
!=
EvaluationStrategy
.
NO
:
column_names
.
append
(
"Validation Loss"
)
self
.
training_tracker
=
NotebookTrainingTracker
(
state
.
max_steps
,
column_names
)
self
.
training_tracker
=
NotebookTrainingTracker
(
state
.
max_steps
,
column_names
)
def
on_step_end
(
self
,
args
,
state
,
control
,
**
kwargs
):
def
on_step_end
(
self
,
args
,
state
,
control
,
**
kwargs
):
epoch
=
int
(
state
.
epoch
)
if
int
(
state
.
epoch
)
==
state
.
epoch
else
f
"
{
state
.
epoch
:.
2
f
}
"
epoch
=
int
(
state
.
epoch
)
if
int
(
state
.
epoch
)
==
state
.
epoch
else
f
"
{
state
.
epoch
:.
2
f
}
"
self
.
training_tracker
.
update
(
state
.
global_step
+
1
,
comment
=
f
"Epoch
{
epoch
}
/
{
state
.
num_train_epochs
}
"
)
self
.
training_tracker
.
update
(
state
.
global_step
+
1
,
comment
=
f
"Epoch
{
epoch
}
/
{
state
.
num_train_epochs
}
"
,
force_update
=
self
.
_force_next_update
,
)
self
.
_force_next_update
=
False
def
on_prediction_step
(
self
,
args
,
state
,
control
,
eval_dataloader
=
None
,
**
kwargs
):
def
on_prediction_step
(
self
,
args
,
state
,
control
,
eval_dataloader
=
None
,
**
kwargs
):
if
self
.
prediction_bar
is
None
:
if
self
.
prediction_bar
is
None
:
...
@@ -294,6 +303,14 @@ class NotebookProgressCallback(TrainerCallback):
...
@@ -294,6 +303,14 @@ class NotebookProgressCallback(TrainerCallback):
else
:
else
:
self
.
prediction_bar
.
update
(
self
.
prediction_bar
.
value
+
1
)
self
.
prediction_bar
.
update
(
self
.
prediction_bar
.
value
+
1
)
def
on_log
(
self
,
args
,
state
,
control
,
logs
=
None
,
**
kwargs
):
# Only for when there is no evaluation
if
args
.
evaluation_strategy
==
EvaluationStrategy
.
NO
and
"loss"
in
logs
:
values
=
{
"Training Loss"
:
logs
[
"loss"
]}
# First column is necessarily Step sine we're not in epoch eval strategy
values
[
"Step"
]
=
state
.
global_step
self
.
training_tracker
.
write_line
(
values
)
def
on_evaluate
(
self
,
args
,
state
,
control
,
metrics
=
None
,
**
kwargs
):
def
on_evaluate
(
self
,
args
,
state
,
control
,
metrics
=
None
,
**
kwargs
):
if
self
.
training_tracker
is
not
None
:
if
self
.
training_tracker
is
not
None
:
values
=
{
"Training Loss"
:
"No log"
}
values
=
{
"Training Loss"
:
"No log"
}
...
@@ -319,6 +336,8 @@ class NotebookProgressCallback(TrainerCallback):
...
@@ -319,6 +336,8 @@ class NotebookProgressCallback(TrainerCallback):
self
.
training_tracker
.
write_line
(
values
)
self
.
training_tracker
.
write_line
(
values
)
self
.
training_tracker
.
remove_child
()
self
.
training_tracker
.
remove_child
()
self
.
prediction_bar
=
None
self
.
prediction_bar
=
None
# Evaluation takes a long time so we should force the next update.
self
.
_force_next_update
=
True
def
on_train_end
(
self
,
args
,
state
,
control
,
**
kwargs
):
def
on_train_end
(
self
,
args
,
state
,
control
,
**
kwargs
):
self
.
training_tracker
.
update
(
self
.
training_tracker
.
update
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment