Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
62b5622e
"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e566adc09c443af843e83b239c3a18b8e7bd422d"
Unverified
Commit
62b5622e
authored
Oct 15, 2020
by
Sylvain Gugger
Committed by
GitHub
Oct 15, 2020
Browse files
Add specific notebook ProgressCalback (#7793)
parent
0911b6bd
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
352 additions
and
2 deletions
+352
-2
src/transformers/file_utils.py
src/transformers/file_utils.py
+18
-0
src/transformers/trainer.py
src/transformers/trainer.py
+7
-2
src/transformers/utils/notebook.py
src/transformers/utils/notebook.py
+327
-0
No files found.
src/transformers/file_utils.py
View file @
62b5622e
...
@@ -142,6 +142,20 @@ try:
...
@@ -142,6 +142,20 @@ try:
except
(
AttributeError
,
ImportError
):
except
(
AttributeError
,
ImportError
):
_has_sklearn
=
False
_has_sklearn
=
False
try:
    # Test copied from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
    # A KeyError on this lookup means IPython was never imported, so we cannot be in a notebook.
    get_ipython = sys.modules["IPython"].get_ipython
    # `get_ipython()` returns None when IPython is imported but no interactive shell is running; the attribute
    # access below then raises AttributeError, which must be treated as "not in a notebook" instead of
    # crashing the import of this module.
    if "IPKernelApp" not in get_ipython().config:
        raise ImportError("console")
    if "VSCODE_PID" in os.environ:
        raise ImportError("vscode")

    import IPython  # noqa: F401

    _in_notebook = True
except (AttributeError, ImportError, KeyError):
    _in_notebook = False
default_cache_path
=
os
.
path
.
join
(
torch_cache_home
,
"transformers"
)
default_cache_path
=
os
.
path
.
join
(
torch_cache_home
,
"transformers"
)
...
@@ -203,6 +217,10 @@ def is_faiss_available():
...
@@ -203,6 +217,10 @@ def is_faiss_available():
return
_faiss_available
return
_faiss_available
def is_in_notebook():
    """Return the module-level ``_in_notebook`` flag computed once when this module is imported."""
    return _in_notebook
def
torch_only_method
(
fn
):
def
torch_only_method
(
fn
):
def
wrapper
(
*
args
,
**
kwargs
):
def
wrapper
(
*
args
,
**
kwargs
):
if
not
_torch_available
:
if
not
_torch_available
:
...
...
src/transformers/trainer.py
View file @
62b5622e
...
@@ -34,7 +34,7 @@ from torch.utils.data.distributed import DistributedSampler
...
@@ -34,7 +34,7 @@ from torch.utils.data.distributed import DistributedSampler
from
torch.utils.data.sampler
import
RandomSampler
,
SequentialSampler
from
torch.utils.data.sampler
import
RandomSampler
,
SequentialSampler
from
.data.data_collator
import
DataCollator
,
DataCollatorWithPadding
,
default_data_collator
from
.data.data_collator
import
DataCollator
,
DataCollatorWithPadding
,
default_data_collator
from
.file_utils
import
WEIGHTS_NAME
,
is_datasets_available
,
is_torch_tpu_available
from
.file_utils
import
WEIGHTS_NAME
,
is_datasets_available
,
is_in_notebook
,
is_torch_tpu_available
from
.integrations
import
(
from
.integrations
import
(
default_hp_search_backend
,
default_hp_search_backend
,
is_comet_available
,
is_comet_available
,
...
@@ -89,7 +89,12 @@ _use_native_amp = False
...
@@ -89,7 +89,12 @@ _use_native_amp = False
_use_apex
=
False
_use_apex
=
False
DEFAULT_CALLBACKS
=
[
DefaultFlowCallback
]
DEFAULT_CALLBACKS
=
[
DefaultFlowCallback
]
# Pick the progress callback once at import time: the notebook-specific one is only imported when running
# inside a notebook, otherwise the regular progress callback is used.
if is_in_notebook():
    from .utils.notebook import NotebookProgressCallback

    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
else:
    DEFAULT_PROGRESS_CALLBACK = ProgressCallback
# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
if
version
.
parse
(
torch
.
__version__
)
<
version
.
parse
(
"1.6"
):
if
version
.
parse
(
torch
.
__version__
)
<
version
.
parse
(
"1.6"
):
...
@@ -235,7 +240,7 @@ class Trainer:
...
@@ -235,7 +240,7 @@ class Trainer:
)
)
callbacks
=
DEFAULT_CALLBACKS
if
callbacks
is
None
else
DEFAULT_CALLBACKS
+
callbacks
callbacks
=
DEFAULT_CALLBACKS
if
callbacks
is
None
else
DEFAULT_CALLBACKS
+
callbacks
self
.
callback_handler
=
CallbackHandler
(
callbacks
,
self
.
model
,
self
.
optimizer
,
self
.
lr_scheduler
)
self
.
callback_handler
=
CallbackHandler
(
callbacks
,
self
.
model
,
self
.
optimizer
,
self
.
lr_scheduler
)
self
.
add_callback
(
PrinterCallback
if
self
.
args
.
disable_tqdm
else
ProgressCallback
)
self
.
add_callback
(
PrinterCallback
if
self
.
args
.
disable_tqdm
else
DEFAULT_PROGRESS_CALLBACK
)
# Deprecated arguments
# Deprecated arguments
if
"tb_writer"
in
kwargs
:
if
"tb_writer"
in
kwargs
:
...
...
src/transformers/utils/notebook.py
0 → 100644
View file @
62b5622e
# coding=utf-8
# Copyright 2020 Hugging Face
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
time
from
typing
import
Optional
import
IPython.display
as
disp
from
..trainer_callback
import
TrainerCallback
def format_time(t):
    """
    Format `t` (in seconds) to (h):mm:ss.

    `t` is truncated to an integer first; the hours part is only emitted when it is non-zero.
    """
    minutes, seconds = divmod(int(t), 60)
    hours, minutes = divmod(minutes, 60)
    if hours == 0:
        return f"{minutes:02d}:{seconds:02d}"
    return f"{hours}:{minutes:02d}:{seconds:02d}"
def html_progress_bar(value, total, prefix, label, width=300):
    """
    Build the HTML snippet for a progress bar at `value`/`total`.

    `prefix` is rendered before the bar and `label` after it; `width` is the bar width in pixels. The values are
    interpolated into the markup as-is.
    """
    bar_tag = (
        f"<progress value='{value}' max='{total}' "
        f"style='width:{width}px; height:20px; vertical-align: middle;'></progress>"
    )
    return f"""
    <div>
        <style>
            /* Turns off some styling */
            progress {{
                /* gets rid of default border in Firefox and Opera. */
                border: none;
                /* Needs to be in here for Safari polyfill so background images work as expected. */
                background-size: auto;
            }}
        </style>
      {prefix}
      {bar_tag}
      {label}
    </div>
    """
def text_to_html_table(items):
    """
    Put the texts in `items` in an HTML table.

    The first entry of `items` holds the header cells and every following entry is one row. Floats are rendered
    with six decimals, anything else goes through `str`.
    """
    pieces = [
        '<table border="1" class="dataframe">\n',
        '  <thead>\n    <tr style="text-align: left;">\n',
    ]
    pieces.extend(f"      <th>{title}</th>\n" for title in items[0])
    pieces.append("    </tr>\n  </thead>\n  <tbody>\n")
    for row in items[1:]:
        pieces.append("    <tr>\n")
        for cell in row:
            text = f"{cell:.6f}" if isinstance(cell, float) else str(cell)
            pieces.append(f"      <td>{text}</td>\n")
        pieces.append("    </tr>\n")
    pieces.append("  </tbody>\n</table><p>")
    return "".join(pieces)
class NotebookProgressBar:
    """
    A progress bar for display in a notebook.

    Class attributes (overridden by derived classes)

        - **warmup** (:obj:`int`) -- The number of iterations to do at the beginning while ignoring
          :obj:`update_every`.
        - **update_every** (:obj:`float`) -- Since calling the time takes some time, we only do it
          every presumed :obj:`update_every` seconds. The progress bar uses the average time passed
          up until now to guess the next value for which it will call the update.

    Args:
        total (:obj:`int`):
            The total number of iterations to reach.
        prefix (:obj:`str`, `optional`):
            A prefix to add before the progress bar.
        leave (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to leave the progress bar once it's completed. You can always call the
            :meth:`~transformers.utils.notebook.NotebookProgressBar.close` method to make the bar disappear.
        parent (:class:`~transformers.utils.notebook.NotebookTrainingTracker`, `optional`):
            A parent object (like :class:`~transformers.utils.notebook.NotebookTrainingTracker`) that spawns progress
            bars and handle their display. If set, the object passed must have a :obj:`display()` method.
        width (:obj:`int`, `optional`, defaults to 300):
            The width (in pixels) that the bar will take.

    Example::

        import time

        pbar = NotebookProgressBar(100)
        for val in range(100):
            pbar.update(val)
            time.sleep(0.07)
        pbar.update(100)
    """

    warmup = 5
    update_every = 0.2

    def __init__(
        self,
        total: int,
        prefix: Optional[str] = None,
        leave: bool = True,
        parent: Optional["NotebookTrainingTracker"] = None,
        width: int = 300,
    ):
        self.total = total
        self.prefix = "" if prefix is None else prefix
        self.leave = leave
        self.parent = parent
        self.width = width
        # `last_value` stays None until the first call to `update`, which initializes the timing state.
        self.last_value = None
        self.comment = None
        self.output = None

    def update(self, value: int, force_update: bool = False, comment: Optional[str] = None):
        """
        The main method to update the progress bar to :obj:`value`.

        Args:
            value (:obj:`int`):
                The value to use. Must be between 0 and :obj:`total`.
            force_update (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to force an update of the internal state and display (by default, the bar will wait
                for :obj:`value` to reach the value it predicted corresponds to a time of more than the
                :obj:`update_every` attribute since the last update to avoid refreshing the display too often).
            comment (:obj:`str`, `optional`):
                A comment appended at the end of the label displayed next to the progress bar.
        """
        self.value = value
        if comment is not None:
            self.comment = comment
        if self.last_value is None:
            # First call: record the starting point and render the bar without any timing information.
            self.start_time = self.last_time = time.time()
            self.start_value = self.last_value = value
            self.elapsed_time = self.predicted_remaining = None
            self.first_calls = self.warmup
            self.wait_for = 1
            self.update_bar(value)
        elif value <= self.last_value and not force_update:
            # Nothing new to show. A forced update must still go through, otherwise `force_update` would be
            # ignored whenever the bar has already been rendered at (or past) `value`.
            return
        elif force_update or self.first_calls > 0 or value >= min(self.last_value + self.wait_for, self.total):
            if self.first_calls > 0:
                self.first_calls -= 1
            current_time = time.time()
            self.elapsed_time = current_time - self.start_time
            # A forced update can arrive with `value == start_value`: there is no item to average over then.
            if value > self.start_value:
                self.average_time_per_item = self.elapsed_time / (value - self.start_value)
            else:
                self.average_time_per_item = None

            finished = value >= self.total
            if finished:
                value = self.total
                self.predicted_remaining = None
            elif self.average_time_per_item is not None:
                self.predicted_remaining = self.average_time_per_item * (self.total - value)
            else:
                self.predicted_remaining = None

            self.update_bar(value)
            if finished and not self.leave:
                # Close after the final render: closing first would be undone by the redraw in `update_bar`.
                self.close()

            self.last_value = value
            self.last_time = current_time
            if self.average_time_per_item:
                self.wait_for = max(int(self.update_every / self.average_time_per_item), 1)
            else:
                # No usable average (unknown, or zero elapsed time on a coarse clock): check again at the
                # next value instead of dividing by zero.
                self.wait_for = 1

    def update_bar(self, value, comment=None):
        """
        Rebuild the text label for :obj:`value` and refresh the display.

        Args:
            value (:obj:`int`): The value shown in the label.
            comment (:obj:`str`, `optional`): If set, replaces the stored comment before building the label.
        """
        if comment is not None:
            self.comment = comment
        # Right-align the value on the number of digits of `total` so the label keeps a constant width.
        spaced_value = str(value).rjust(len(str(self.total)))
        if self.elapsed_time is None:
            self.label = f"[{spaced_value}/{self.total} : < :"
        elif self.predicted_remaining is None:
            self.label = f"[{spaced_value}/{self.total} {format_time(self.elapsed_time)}"
        else:
            self.label = (
                f"[{spaced_value}/{self.total} {format_time(self.elapsed_time)} <"
                f" {format_time(self.predicted_remaining)}"
            )
            if self.average_time_per_item == 0:
                # Zero elapsed time so far: the rate is not measurable yet.
                self.label += ", +inf it/s"
            else:
                self.label += f", {1/self.average_time_per_item:.2f} it/s"
        self.label += "]" if self.comment is None or len(self.comment) == 0 else f", {self.comment}]"
        self.display()

    def display(self):
        """Render the HTML of the bar and show it, or delegate the display to the parent if there is one."""
        self.html_code = html_progress_bar(self.value, self.total, self.prefix, self.label, self.width)
        if self.parent is not None:
            # If this is a child bar, the parent will take care of the display.
            self.parent.display()
            return
        if self.output is None:
            self.output = disp.display(disp.HTML(self.html_code), display_id=True)
        else:
            self.output.update(disp.HTML(self.html_code))

    def close(self):
        "Closes the progress bar."
        # Only a standalone bar that has already been displayed owns an output to clear.
        if self.parent is None and self.output is not None:
            self.output.update(disp.HTML(""))
class NotebookTrainingTracker(NotebookProgressBar):
    """
    Tracks an ongoing training in a notebook: a main progress bar, a table of metrics and an optional child bar.

    Args:
        num_steps (:obj:`int`): The number of steps during training.
        column_names (:obj:`List[str]`, `optional`):
            The list of column names for the metrics table (will be inferred from the first call to
            :meth:`~transformers.utils.notebook.NotebookTrainingTracker.write_line` if not set).
    """

    def __init__(self, num_steps, column_names=None):
        super().__init__(num_steps)
        # The first row of the table holds the column names; the table does not exist until they are known.
        if column_names is None:
            self.inner_table = None
        else:
            self.inner_table = [column_names]
        self.child_bar = None

    def display(self):
        """Render the main bar, then the metrics table and the child bar when present, and show the result."""
        html = html_progress_bar(self.value, self.total, self.prefix, self.label, self.width)
        if self.inner_table is not None:
            html += text_to_html_table(self.inner_table)
        if self.child_bar is not None:
            html += self.child_bar.html_code
        self.html_code = html

        rendered = disp.HTML(self.html_code)
        if self.output is None:
            self.output = disp.display(rendered, display_id=True)
        else:
            self.output.update(rendered)

    def write_line(self, values):
        """
        Write the values in the inner table.

        Args:
            values (:obj:`Dict[str, float]`): The values to display.
        """
        if self.inner_table is None:
            # No header yet: the keys become the column names and the values the first row.
            self.inner_table = [list(values.keys()), list(values.values())]
            return

        header = self.inner_table[0]
        if len(self.inner_table) == 1:
            # We give a chance to update the column names at the first iteration
            new_names = [name for name in values if name not in header]
            header.extend(new_names)
        self.inner_table.append([values[name] for name in header])

    def add_child(self, total, prefix=None, width=300):
        """
        Add a child progress bar displayed under the table of metrics. The child progress bar is returned (so it can
        be easily updated).

        Args:
            total (:obj:`int`): The number of iterations for the child progress bar.
            prefix (:obj:`str`, `optional`): A prefix to write on the left of the progress bar.
            width (:obj:`int`, `optional`, defaults to 300): The width (in pixels) of the progress bar.
        """
        child = NotebookProgressBar(total, prefix=prefix, parent=self, width=width)
        self.child_bar = child
        return child

    def remove_child(self):
        """
        Closes the child progress bar.
        """
        self.child_bar = None
        self.display()
class NotebookProgressCallback(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that displays the progress of training or evaluation, optimized for
    Jupyter Notebooks or Google colab.
    """

    def __init__(self):
        # Both are created lazily: the tracker when training begins, the bar at the first prediction step.
        self.training_tracker = None
        self.prediction_bar = None

    def on_train_begin(self, args, state, control, **kwargs):
        """Create the training tracker with its default columns."""
        self.first_column = "Epoch" if args.max_steps <= 0 else "Step"
        self.training_loss = 0
        self.last_log = 0
        column_names = [self.first_column] + ["Training Loss", "Validation Loss"]
        self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names)

    def on_step_end(self, args, state, control, **kwargs):
        """Advance the training bar and show the current epoch in its comment."""
        # Show whole epochs as integers and partial epochs with two decimals.
        epoch = int(state.epoch) if int(state.epoch) == state.epoch else f"{state.epoch:.2f}"
        self.training_tracker.update(state.global_step + 1, comment=f"Epoch {epoch}/{state.num_train_epochs}")

    def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
        """Advance the evaluation bar, creating it (as a child of the tracker if any) on the first step."""
        if self.prediction_bar is None:
            try:
                num_batches = len(eval_dataloader)
            except TypeError:
                # No dataloader was passed or it has no length: a bar needs a total, so display nothing.
                return
            if self.training_tracker is not None:
                self.prediction_bar = self.training_tracker.add_child(num_batches)
            else:
                self.prediction_bar = NotebookProgressBar(num_batches)
            self.prediction_bar.update(1)
        else:
            self.prediction_bar.update(self.prediction_bar.value + 1)

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Add a row with the latest training loss and the evaluation metrics to the tracker table."""
        if self.training_tracker is not None:
            values = {"Training Loss": "No log"}
            # Use the most recent logged training loss, if any.
            for log in reversed(state.log_history):
                if "loss" in log:
                    values["Training Loss"] = log["loss"]
                    break

            if self.first_column == "Epoch":
                values["Epoch"] = int(state.epoch)
            else:
                values["Step"] = state.global_step

            # Placeholder so the row stays complete when no evaluation loss is reported.
            values["Validation Loss"] = "No log"
            for key, value in (metrics or {}).items():
                # These two entries are not displayed. They are skipped rather than popped so the caller's
                # dictionary is left untouched.
                if key in ("total_flos", "epoch"):
                    continue
                if key == "eval_loss":
                    values["Validation Loss"] = value
                else:
                    # Drop the leading "<prefix>_" part and capitalize the rest; a key without any underscore
                    # is kept whole so it does not end up with an empty column name.
                    parts = key.split("_")
                    if len(parts) > 1:
                        parts = parts[1:]
                    name = " ".join(part.capitalize() for part in parts)
                    values[name] = value
            self.training_tracker.write_line(values)
            self.training_tracker.remove_child()
        # Always drop the bar so the next evaluation starts a fresh one, including evaluations run outside of
        # training.
        self.prediction_bar = None

    def on_train_end(self, args, state, control, **kwargs):
        """Force a last refresh of the training bar and release the tracker."""
        self.training_tracker.update(
            state.global_step, comment=f"Epoch {int(state.epoch)}/{state.num_train_epochs}", force_update=True
        )
        self.training_tracker = None
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment