chenpangpang / transformers

Commit ee04b698 (unverified)
Authored Feb 26, 2021 by Stas Bekman; committed via GitHub on Feb 26, 2021
[examples] better model example (#10427)

* refactors
* typo

Parent: a85eb616
Showing 3 changed files with 46 additions and 20 deletions:
examples/seq2seq/run_seq2seq.py       +8   -18
src/transformers/trainer.py           +1    -1
src/transformers/trainer_pt_utils.py  +37   -1
examples/seq2seq/run_seq2seq.py

```diff
@@ -572,7 +572,6 @@ def main():
         compute_metrics=compute_metrics if training_args.predict_with_generate else None,
     )
-    all_metrics = {}
 
     # Training
     if training_args.do_train:
         if last_checkpoint is not None:
@@ -589,13 +588,10 @@ def main():
             data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
         )
         metrics["train_samples"] = min(max_train_samples, len(train_dataset))
-        if trainer.is_world_process_zero():
-            trainer.log_metrics("train", metrics)
-            trainer.save_metrics("train", metrics)
-            all_metrics.update(metrics)
 
-            # Need to save the state, since Trainer.save_model saves only the tokenizer with the model
-            trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
+        trainer.log_metrics("train", metrics)
+        trainer.save_metrics("train", metrics)
+        trainer.save_state()
 
     # Evaluation
     results = {}
@@ -608,10 +604,8 @@ def main():
         max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
         metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
-        if trainer.is_world_process_zero():
-            trainer.log_metrics("eval", metrics)
-            trainer.save_metrics("eval", metrics)
-            all_metrics.update(metrics)
+
+        trainer.log_metrics("eval", metrics)
+        trainer.save_metrics("eval", metrics)
 
     if training_args.do_predict:
         logger.info("*** Test ***")
@@ -626,11 +620,10 @@ def main():
         max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
         metrics["test_samples"] = min(max_test_samples, len(test_dataset))
-        if trainer.is_world_process_zero():
-            trainer.log_metrics("test", metrics)
-            trainer.save_metrics("test", metrics)
-            all_metrics.update(metrics)
 
+        trainer.log_metrics("test", metrics)
+        trainer.save_metrics("test", metrics)
+
+        if trainer.is_world_process_zero():
             if training_args.predict_with_generate:
                 test_preds = tokenizer.batch_decode(
                     test_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
@@ -640,9 +633,6 @@ def main():
             with open(output_test_preds_file, "w") as writer:
                 writer.write("\n".join(test_preds))
 
-    if trainer.is_world_process_zero():
-        trainer.save_metrics("all", metrics)
-
     return results
```
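The run_seq2seq.py changes all follow one idea: the `if trainer.is_world_process_zero():` guards and the hand-maintained `all_metrics` dict move out of the example script and into the Trainer helpers, so every call site shrinks to a bare `log_metrics`/`save_metrics`/`save_state` call. A self-contained sketch of that pattern in plain Python (`RANK` and the helper names here are illustrative stand-ins, not the transformers API):

```python
import os

# Stand-in for the process rank a distributed launcher would export; the
# real code asks trainer.is_world_process_zero() instead.
RANK = int(os.environ.get("RANK", "0"))

def is_world_process_zero():
    return RANK == 0

def log_metrics(split, metrics):
    # The rank-zero guard now lives inside the helper...
    if not is_world_process_zero():
        return
    print(f"***** {split} metrics *****")
    for key in sorted(metrics):
        print(f"  {key} = {metrics[key]}")

# ...so call sites need no per-site `if is_world_process_zero():` wrapper:
log_metrics("train", {"train_loss": 1.93, "train_samples": 100})
```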
src/transformers/trainer.py

```diff
@@ -231,7 +231,7 @@ class Trainer:
     """
 
-    from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics
+    from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state
 
     def __init__(
         self,
```
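The one-line trainer.py change leans on a Python detail worth spelling out: an `import` statement inside a class body binds the imported name as a class attribute, so a module-level function written with a `self` parameter becomes an ordinary method of that class. A runnable sketch of the same trick (the `helpers` module and `Greeter` class are invented for illustration):

```python
import sys
import types

# Fabricate a helper module standing in for transformers.trainer_pt_utils.
helpers = types.ModuleType("helpers")
exec(
    "def describe(self):\n"
    "    return f'Greeter named {self.name}'\n",
    helpers.__dict__,
)
sys.modules["helpers"] = helpers

class Greeter:
    # Same trick as `from .trainer_pt_utils import ..., save_state` above:
    # the import binds `describe` in the class namespace, making it a method.
    from helpers import describe

    def __init__(self, name):
        self.name = name

print(Greeter("trainer").describe())  # -> Greeter named trainer
```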
src/transformers/trainer_pt_utils.py

```diff
@@ -599,12 +599,16 @@ def log_metrics(self, split, metrics):
     """
     Log metrics in a specially formatted way
 
+    Under distributed environment this is done only for a process with rank 0.
+
     Args:
         split (:obj:`str`):
             Mode/split name: one of ``train``, ``eval``, ``test``
         metrics (:obj:`Dict[str, float]`):
             The metrics returned from train/evaluate/predictmetrics: metrics dict
     """
+    if not self.is_world_process_zero():
+        return
+
     logger.info(f"***** {split} metrics *****")
     metrics_formatted = self.metrics_format(metrics)
@@ -614,16 +618,48 @@ def log_metrics(self, split, metrics):
         logger.info(f"  {key: <{k_width}} = {metrics_formatted[key]: >{v_width}}")
 
 
-def save_metrics(self, split, metrics):
+def save_metrics(self, split, metrics, combined=True):
     """
     Save metrics into a json file for that split, e.g. ``train_results.json``.
 
+    Under distributed environment this is done only for a process with rank 0.
+
     Args:
         split (:obj:`str`):
             Mode/split name: one of ``train``, ``eval``, ``test``, ``all``
         metrics (:obj:`Dict[str, float]`):
             The metrics returned from train/evaluate/predict
+        combined (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            Creates combined metrics by updating ``all_results.json`` with metrics of this call
     """
+    if not self.is_world_process_zero():
+        return
+
     path = os.path.join(self.args.output_dir, f"{split}_results.json")
     with open(path, "w") as f:
         json.dump(metrics, f, indent=4, sort_keys=True)
+
+    if combined:
+        path = os.path.join(self.args.output_dir, "all_results.json")
+        if os.path.exists(path):
+            with open(path, "r") as f:
+                all_metrics = json.load(f)
+        else:
+            all_metrics = {}
+
+        all_metrics.update(metrics)
+        with open(path, "w") as f:
+            json.dump(all_metrics, f, indent=4, sort_keys=True)
+
+
+def save_state(self):
+    """
+    Saves the Trainer state, since Trainer.save_model saves only the tokenizer with the model
+
+    Under distributed environment this is done only for a process with rank 0.
+    """
+    if not self.is_world_process_zero():
+        return
+
+    path = os.path.join(self.args.output_dir, "trainer_state.json")
+    self.state.save_to_json(path)
```
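With `combined=True` (the default), every `save_metrics` call now writes its own `{split}_results.json` and also folds the metrics into a cumulative `all_results.json`, which is what lets run_seq2seq.py drop its hand-rolled `all_metrics` dict. A minimal standalone sketch of that merge behavior (the `save_metrics_sketch` helper and the `out` directory are illustrative; the file names match the diff):

```python
import json
import os

def save_metrics_sketch(output_dir, split, metrics, combined=True):
    os.makedirs(output_dir, exist_ok=True)
    # Per-split file, e.g. train_results.json or eval_results.json.
    with open(os.path.join(output_dir, f"{split}_results.json"), "w") as f:
        json.dump(metrics, f, indent=4, sort_keys=True)

    if combined:
        # Read-modify-write the shared file; later calls overwrite shared keys.
        path = os.path.join(output_dir, "all_results.json")
        all_metrics = {}
        if os.path.exists(path):
            with open(path) as f:
                all_metrics = json.load(f)
        all_metrics.update(metrics)
        with open(path, "w") as f:
            json.dump(all_metrics, f, indent=4, sort_keys=True)

save_metrics_sketch("out", "train", {"train_loss": 1.9, "train_samples": 100})
save_metrics_sketch("out", "eval", {"eval_loss": 2.1, "eval_samples": 50})

with open("out/all_results.json") as f:
    print(f.read())  # contains both the train_* and eval_* keys
```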