chenpangpang / transformers / Commits

Unverified commit 1198ba8f, authored Dec 18, 2020 by Sylvain Gugger, committed by GitHub on Dec 18, 2020

Add timing inside Trainer (#9196)

* Add timing inside Trainer
* Fix tests
* Add n_objs for train
* Sort logs
parent 9a25c5bd

Showing 6 changed files with 76 additions and 49 deletions:
examples/seq2seq/finetune_trainer.py   +7  -34
examples/test_examples.py              +1  -3
src/transformers/trainer.py            +14 -3
src/transformers/trainer_utils.py      +22 -0
src/transformers/training_args.py      +12 -4
tests/test_trainer.py                  +20 -5
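Before the per-file diffs, a rough sketch of the user-facing effect (not part of the commit itself): Trainer.train now returns a TrainOutput whose metrics field carries the timing numbers, and Trainer.evaluate / Trainer.predict merge the same keys, prefixed with metric_key_prefix, into the metrics they report and log. Assuming an already-configured `trainer` instance, the new keys look roughly like this:

# Illustrative only; `trainer` is assumed to be a fully configured transformers.Trainer.
train_result = trainer.train()
print(train_result.metrics)
# roughly: {"train_runtime": 42.1234, "train_samples_per_second": 11.87, "total_flos": ...}

eval_metrics = trainer.evaluate()                # keys use metric_key_prefix ("eval" unless overridden)
print(eval_metrics["eval_runtime"])              # wall-clock seconds spent in the evaluation loop
print(eval_metrics["eval_samples_per_second"])   # sample count divided by runtime, rounded to 3 decimals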
examples/seq2seq/finetune_trainer.py
@@ -16,7 +16,6 @@
 import logging
 import os
 import sys
-import time
 from dataclasses import dataclass, field
 from typing import Optional

@@ -120,30 +119,6 @@ class DataTrainingArguments:
     )


-def speed_metrics(split, start_time, num_samples):
-    """
-    Measure and return speed performance metrics.
-
-    This function requires a time snapshot `start_time` before the operation to be measured starts and this
-    function should be run immediately after the operation to be measured has completed.
-
-    Args:
-    - split: one of train, val, test
-    - start_time: operation start time
-    - num_samples: number of samples processed
-    """
-    runtime = time.time() - start_time
-    result = {}
-    samples_per_second = 1 / (runtime / num_samples)
-    result[f"{split}_samples_per_second"] = round(samples_per_second, 3)
-    result[f"{split}_runtime"] = round(runtime, 4)
-    result[f"{split}_n_ojbs"] = num_samples
-    return result
-
-
 def handle_metrics(split, metrics, output_dir):
     """
     Log and save metrics

@@ -155,8 +130,8 @@ def handle_metrics(split, metrics, output_dir):
     """

     logger.info(f"***** {split} metrics *****")
-    for key, value in metrics.items():
-        logger.info(f"  {key} = {value}")
+    for key in sorted(metrics.keys()):
+        logger.info(f"  {key} = {metrics[key]}")
     save_json(metrics, os.path.join(output_dir, f"{split}_results.json"))

@@ -311,11 +286,11 @@ def main():
    if training_args.do_train:
        logger.info("*** Train ***")

-        start_time = time.time()
-        trainer.train(
+        train_result = trainer.train(
            model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
        )
-        metrics = speed_metrics("train", start_time, data_args.n_train)
+        metrics = train_result.metrics
+        metrics["train_n_objs"] = data_args.n_train

        trainer.save_model()  # this also saves the tokenizer

@@ -334,9 +309,8 @@ def main():
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

-        start_time = time.time()
        metrics = trainer.evaluate(metric_key_prefix="val")
-        metrics.update(speed_metrics("val", start_time, data_args.n_val))
+        metrics["val_n_objs"] = data_args.n_val
        metrics["val_loss"] = round(metrics["val_loss"], 4)

        if trainer.is_world_process_zero():

@@ -347,10 +321,9 @@ def main():
    if training_args.do_predict:
        logger.info("*** Predict ***")

-        start_time = time.time()
        test_output = trainer.predict(test_dataset=test_dataset, metric_key_prefix="test")
        metrics = test_output.metrics
-        metrics.update(speed_metrics("test", start_time, data_args.n_test))
+        metrics["test_n_objs"] = data_args.n_test

        if trainer.is_world_process_zero():
            metrics["test_loss"] = round(metrics["test_loss"], 4)
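The "Sort logs" part of the commit message lands in handle_metrics above: instead of iterating metrics.items() in insertion order, the script now prints keys alphabetically. A minimal, self-contained sketch of the new behaviour (the sample values are invented for illustration):

# Invented sample metrics, only to show the ordering change in handle_metrics.
metrics = {"val_runtime": 12.3, "val_loss": 1.2345, "val_n_objs": 500}

for key in sorted(metrics.keys()):       # new behaviour: deterministic, alphabetical order
    print(f"  {key} = {metrics[key]}")
# prints val_loss, then val_n_objs, then val_runtime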
examples/test_examples.py
@@ -97,9 +97,7 @@ class ExamplesTests(TestCasePlus):
        with patch.object(sys, "argv", testargs):
            result = run_glue.main()
-            del result["eval_loss"]
-            for value in result.values():
-                self.assertGreaterEqual(value, 0.75)
+            self.assertGreaterEqual(result["eval_accuracy"], 0.75)

    @require_torch_non_multi_gpu_but_fix_me
    def test_run_clm(self):
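The test change follows from the new timing keys: once the result dict can contain run-dependent entries such as eval_runtime and eval_samples_per_second, a blanket ">= 0.75 for every value" check becomes brittle and meaningless, so the test pins the check to eval_accuracy. A toy illustration with invented numbers:

# Invented result dict mimicking run_glue.main() output after this commit.
result = {"eval_accuracy": 0.86, "eval_runtime": 0.6, "eval_samples_per_second": 270.0}

# A blanket >= 0.75 check over all values is now unreliable -- the runtime happens to fail it here:
assert not all(value >= 0.75 for value in result.values())
# The rewritten test targets the one metric that actually has a 0.75 threshold:
assert result["eval_accuracy"] >= 0.75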
src/transformers/trainer.py
@@ -22,6 +22,7 @@ import math
 import os
 import re
 import shutil
+import time
 import warnings
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

@@ -89,6 +90,7 @@ from .trainer_utils import (
     default_compute_objective,
     default_hp_space,
     set_seed,
+    speed_metrics,
 )
 from .training_args import TrainingArguments
 from .utils import logging

@@ -707,6 +709,7 @@ class Trainer:
        logger.info(f"  Total optimization steps = {max_steps}")

        self.state.epoch = 0
+        start_time = time.time()
        epochs_trained = 0
        steps_trained_in_current_epoch = 0

@@ -870,15 +873,17 @@ class Trainer:
            state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
            self.model.load_state_dict(state_dict)

+        metrics = speed_metrics("train", start_time, self.state.max_steps)
        if self._total_flos is not None:
            self.store_flos()
-            self.log({"total_flos": self.state.total_flos})
+            metrics["total_flos"] = self.state.total_flos
+        self.log(metrics)

        self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
        # add remaining tr_loss
        self._total_loss_scalar += tr_loss.item()

-        return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step)
+        return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step, metrics)

    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
        if self.control.should_log:

@@ -1317,6 +1322,7 @@ class Trainer:
            raise ValueError("eval_dataset must implement __len__")

        eval_dataloader = self.get_eval_dataloader(eval_dataset)
+        start_time = time.time()

        output = self.prediction_loop(
            eval_dataloader,

@@ -1328,6 +1334,8 @@ class Trainer:
            metric_key_prefix=metric_key_prefix,
        )

+        n_samples = len(eval_dataset if eval_dataset is not None else self.eval_dataset)
+        output.metrics.update(speed_metrics(metric_key_prefix, start_time, n_samples))
        self.log(output.metrics)

        if self.args.tpu_metrics_debug or self.args.debug:

@@ -1374,10 +1382,13 @@ class Trainer:
            raise ValueError("test_dataset must implement __len__")

        test_dataloader = self.get_test_dataloader(test_dataset)
+        start_time = time.time()

-        return self.prediction_loop(
+        output = self.prediction_loop(
            test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
        )
+        output.metrics.update(speed_metrics(metric_key_prefix, start_time, len(test_dataset)))
+        return output

    def prediction_loop(
        self,
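One detail worth noting from the hunks above: for training, speed_metrics is fed self.state.max_steps, so train_samples_per_second effectively reports optimization steps per second, whereas evaluate and predict pass genuine sample counts. A small self-contained sketch of that wiring, with local stand-ins rather than real Trainer objects:

import time

# Local stand-ins for the values Trainer passes to speed_metrics; not the real Trainer state.
max_steps = 500            # train() passes self.state.max_steps here
start_time = time.time()
time.sleep(0.25)           # stand-in for the optimization loop

runtime = time.time() - start_time
train_metrics = {
    "train_runtime": round(runtime, 4),
    "train_samples_per_second": round(max_steps / runtime, 3),  # effectively steps per second
}
print(train_metrics)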
src/transformers/trainer_utils.py
@@ -18,6 +18,7 @@ Utilities for the Trainer and TFTrainer class. Should be independent from PyTorch
 import copy
 import random
+import time
 from typing import Any, Dict, NamedTuple, Optional, Tuple, Union

 import numpy as np

@@ -70,6 +71,7 @@ class PredictionOutput(NamedTuple):
 class TrainOutput(NamedTuple):
     global_step: int
     training_loss: float
+    metrics: Dict[str, float]


 PREFIX_CHECKPOINT_DIR = "checkpoint"

@@ -179,3 +181,23 @@ def total_processes_number(local_rank):
        return torch.distributed.get_world_size()
    return 1
+
+
+def speed_metrics(split, start_time, num_samples=None):
+    """
+    Measure and return speed performance metrics.
+
+    This function requires a time snapshot `start_time` before the operation to be measured starts and this function
+    should be run immediately after the operation to be measured has completed.
+
+    Args:
+    - split: name to prefix metric (like train, eval, test...)
+    - start_time: operation start time
+    - num_samples: number of samples processed
+    """
+    runtime = time.time() - start_time
+    result = {f"{split}_runtime": round(runtime, 4)}
+    if num_samples is not None:
+        samples_per_second = 1 / (runtime / num_samples)
+        result[f"{split}_samples_per_second"] = round(samples_per_second, 3)
+    return result
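A short usage sketch for the new helper, assuming a transformers install that includes this commit (the sleep is only a stand-in for the operation being measured):

import time

from transformers.trainer_utils import speed_metrics  # added by this commit

start = time.time()
time.sleep(0.5)                                        # stand-in for the measured operation
print(speed_metrics("test", start))                    # {"test_runtime": ~0.5}
print(speed_metrics("test", start, num_samples=100))   # also includes "test_samples_per_second"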
src/transformers/training_args.py
@@ -12,10 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import dataclasses
 import json
 import os
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple

@@ -411,7 +410,16 @@ class TrainingArguments:
            self.run_name = self.output_dir

        if is_torch_available() and self.device.type != "cuda" and self.fp16:
-            raise ValueError("AMP (`--fp16`) can only be used on CUDA devices.")
+            raise ValueError("Mixed precision training with AMP or APEX (`--fp16`) can only be used on CUDA devices.")
+
+    def __repr__(self):
+        # We override the default repr to remove deprecated arguments from the repr. This method should be removed once
+        # those deprecated arguments are removed form TrainingArguments. (TODO: v5)
+        self_as_dict = asdict(self)
+        del self_as_dict["per_gpu_train_batch_size"]
+        del self_as_dict["per_gpu_eval_batch_size"]
+        attrs_as_str = [f"{k}={v}" for k, v in self_as_dict.items()]
+        return f"{self.__class__.__name__}({', '.join(attrs_as_str)})"

    @property
    def train_batch_size(self) -> int:

@@ -523,7 +531,7 @@ class TrainingArguments:
        """
        Serializes this instance while replace `Enum` by their values (for JSON serialization support).
        """
-        d = dataclasses.asdict(self)
+        d = asdict(self)
        for k, v in d.items():
            if isinstance(v, Enum):
                d[k] = v.value
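A sketch of the new __repr__ behaviour, assuming a transformers install that contains this change: the deprecated per_gpu_* fields are dropped from the printed representation, while the remaining fields still appear as k=v pairs.

from transformers import TrainingArguments

# "out" is just a placeholder output directory for this illustration.
args = TrainingArguments(output_dir="out")
text = repr(args)
assert "per_gpu_train_batch_size" not in text
assert "per_gpu_eval_batch_size" not in text
assert text.startswith("TrainingArguments(")
print(text)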
tests/test_trainer.py
@@ -265,6 +265,21 @@ class TrainerIntegrationTest(unittest.TestCase):
            metrics = trainer.evaluate()
            self.assertEqual(metrics[metric], best_value)

+    def check_trainer_state_are_the_same(self, trainer_state, trainer_state1):
+        # We'll pop things so operate on copies.
+        state = trainer_state.copy()
+        state1 = trainer_state1.copy()
+        # Log history main contain different logs for the time metrics (after resuming a training).
+        log_history = state.pop("log_history", None)
+        log_history1 = state1.pop("log_history", None)
+        self.assertEqual(state, state1)
+        for log, log1 in zip(log_history, log_history1):
+            _ = log.pop("train_runtime", None)
+            _ = log1.pop("train_runtime", None)
+            _ = log.pop("train_samples_per_second", None)
+            _ = log1.pop("train_samples_per_second", None)
+            self.assertEqual(log, log1)
+
    def test_trainer_works_with_dict(self):
        # Edge case because Apex with mode O2 will change our models to return dicts. This test checks it doesn't break
        # anything.

@@ -552,7 +567,7 @@ class TrainerIntegrationTest(unittest.TestCase):
        state1 = dataclasses.asdict(trainer.state)
        self.assertEqual(a, a1)
        self.assertEqual(b, b1)
-        self.assertEqual(state, state1)
+        self.check_trainer_state_are_the_same(state, state1)

        # Now check with a later checkpoint that it also works when we span over one epoch
        checkpoint = os.path.join(tmpdir, "checkpoint-15")

@@ -566,7 +581,7 @@ class TrainerIntegrationTest(unittest.TestCase):
        state1 = dataclasses.asdict(trainer.state)
        self.assertEqual(a, a1)
        self.assertEqual(b, b1)
-        self.assertEqual(state, state1)
+        self.check_trainer_state_are_the_same(state, state1)

        # With a regular model that is not a PreTrainedModel
        with tempfile.TemporaryDirectory() as tmpdir:

@@ -590,7 +605,7 @@ class TrainerIntegrationTest(unittest.TestCase):
        state1 = dataclasses.asdict(trainer.state)
        self.assertEqual(a, a1)
        self.assertEqual(b, b1)
-        self.assertEqual(state, state1)
+        self.check_trainer_state_are_the_same(state, state1)

        # Now check with a later checkpoint that it also works when we span over one epoch
        checkpoint = os.path.join(tmpdir, "checkpoint-15")

@@ -606,7 +621,7 @@ class TrainerIntegrationTest(unittest.TestCase):
        state1 = dataclasses.asdict(trainer.state)
        self.assertEqual(a, a1)
        self.assertEqual(b, b1)
-        self.assertEqual(state, state1)
+        self.check_trainer_state_are_the_same(state, state1)

    def test_resume_training_with_gradient_accumulation(self):
        if torch.cuda.device_count() > 2:

@@ -638,7 +653,7 @@ class TrainerIntegrationTest(unittest.TestCase):
        state1 = dataclasses.asdict(trainer.state)
        self.assertEqual(a, a1)
        self.assertEqual(b, b1)
-        self.assertEqual(state, state1)
+        self.check_trainer_state_are_the_same(state, state1)

    def test_load_best_model_at_end(self):
        total = int(self.n_epochs * 64 / self.batch_size)
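Why the tests needed the new helper: once timing metrics land in log_history, two otherwise-identical runs (a full run versus one resumed from a checkpoint) agree on everything except their wall-clock numbers, so an exact assertEqual on the whole state would flake. A toy illustration with invented numbers:

# Two log entries that agree on everything except the run-dependent timing keys.
log = {"loss": 0.5, "train_runtime": 12.34, "train_samples_per_second": 8.1}
log1 = {"loss": 0.5, "train_runtime": 6.78, "train_samples_per_second": 14.7}

for entry in (log, log1):
    entry.pop("train_runtime", None)
    entry.pop("train_samples_per_second", None)

assert log == log1  # the comparison passes once the timing keys are stripped, as in the helper above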