chenpangpang / transformers / Commits / 562b6369

Unverified commit 562b6369, authored Jul 30, 2020 by Sylvain Gugger, committed by GitHub on Jul 30, 2020.

Tf trainer cleanup (#6143)

* Clean up TFTrainer
* Add import
* Fix conflicts

Parent: c127d055
Showing 1 changed file with 69 additions and 23 deletions.

src/transformers/trainer_tf.py (+69, -23):
@@ -5,6 +5,7 @@ import logging
 import math
 import os
 import sys
+import warnings
 from typing import Callable, Dict, Optional, Tuple

 import numpy as np
@@ -104,7 +105,7 @@ class TFTrainer:
         self.tb_writer = tf.summary.create_file_writer(self.args.logging_dir)

         if is_wandb_available():
-            self._setup_wandb()
+            self.setup_wandb()
         elif os.environ.get("WANDB_DISABLED") != "true":
             logger.info(
                 "You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
@@ -116,6 +117,8 @@ class TFTrainer:
     def get_train_tfdataset(self) -> tf.data.Dataset:
         """
         Returns the training :class:`~tf.data.Dataset`.
+
+        Subclass and override this method if you want to inject some custom behavior.
         """
         if self.train_dataset is None:
             raise ValueError("Trainer: training requires a train_dataset.")
@@ -142,6 +145,8 @@ class TFTrainer:
         Args:
             eval_dataset (:class:`~tf.data.Dataset`, `optional`):
                 If provided, will override `self.eval_dataset`.
+
+        Subclass and override this method if you want to inject some custom behavior.
         """
         if eval_dataset is None and self.eval_dataset is None:
             raise ValueError("Trainer: evaluation requires an eval_dataset.")
@@ -168,6 +173,8 @@ class TFTrainer:
         Args:
             test_dataset (:class:`~tf.data.Dataset`): The dataset to use.
+
+        Subclass and override this method if you want to inject some custom behavior.
         """

         num_examples = tf.data.experimental.cardinality(test_dataset).numpy()
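The three dataset getters above now carry the same "Subclass and override" note, making them the documented extension point. A minimal sketch of such an override, using a hypothetical subclass name (the logger setup is illustration only; `self.train_dataset` is the attribute the method already checks):

import logging

import tensorflow as tf
from transformers import TFTrainer

logger = logging.getLogger(__name__)


class LoggingTFTrainer(TFTrainer):  # hypothetical subclass, for illustration only
    def get_train_tfdataset(self) -> tf.data.Dataset:
        # Inject custom behavior around the stock pipeline: report the raw
        # dataset size, then defer to the default implementation.
        cardinality = tf.data.experimental.cardinality(self.train_dataset).numpy()
        logger.info("Training dataset cardinality: %d", cardinality)
        return super().get_train_tfdataset()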
@@ -185,14 +192,12 @@ class TFTrainer:
         return self.args.strategy.experimental_distribute_dataset(ds), steps, num_examples

-    def create_optimizer_and_scheduler(
-        self, num_training_steps: int,
-    ) -> Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule]:
+    def create_optimizer_and_scheduler(self, num_training_steps: int):
         """
         Setup the optimizer and the learning rate scheduler.

         We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
-        TFTrainer's init through :obj:`optimizers`, or override this method in a subclass.
+        TFTrainer's init through :obj:`optimizers`, or subclass and override this method.
         """
         if not self.optimizer and not self.lr_scheduler:
             self.optimizer, self.lr_scheduler = create_optimizer(
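For reference, the alternative the docstring points to (passing your own pair through :obj:`optimizers` at init) looks roughly like the sketch below. The model checkpoint, toy dataset, and output directory are placeholders and not part of this commit:

import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, TFTrainer, TFTrainingArguments

model = TFAutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")

# Toy (features, labels) dataset so the sketch is self-contained.
features = {
    "input_ids": tf.constant([[101, 2023, 102]] * 8),
    "attention_mask": tf.constant([[1, 1, 1]] * 8),
}
labels = tf.constant([0, 1, 0, 1, 0, 1, 0, 1])
train_dataset = tf.data.Dataset.from_tensor_slices((features, labels))

# A custom schedule/optimizer pair, handed over instead of the default
# built by create_optimizer_and_scheduler.
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=5e-5, decay_steps=1000, end_learning_rate=0.0
)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

trainer = TFTrainer(
    model=model,
    args=TFTrainingArguments(output_dir="out"),
    train_dataset=train_dataset,
    optimizers=(optimizer, lr_schedule),
)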
@@ -205,12 +210,12 @@ class TFTrainer:
                 weight_decay_rate=self.args.weight_decay,
             )

-    def _setup_wandb(self):
+    def setup_wandb(self):
         """
         Setup the optional Weights & Biases (`wandb`) integration.

-        One can override this method to customize the setup if needed. Find more information at https://docs.wandb.com/huggingface
-        You can also override the following environment variables:
+        One can subclass and override this method to customize the setup if needed. Find more information
+        `here <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:

         Environment:
             WANDB_PROJECT:
@@ -218,10 +223,17 @@ class TFTrainer:
             WANDB_DISABLED:
                 (Optional): boolean - defaults to false, set to "true" to disable wandb entirely
         """
+        if hasattr(self, "_setup_wandb"):
+            warnings.warn(
+                "The `_setup_wandb` method is deprecated and won't be called in a future version, define `setup_wandb` in your subclass.",
+                FutureWarning,
+            )
+            return self._setup_wandb()
+
         logger.info('Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"')
         wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=vars(self.args))

-    def _prediction_loop(
+    def prediction_loop(
         self,
         dataset: tf.data.Dataset,
         steps: int,
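Each renamed method keeps this same backward-compatibility shim: if a subclass still defines the old underscore name, it is called, with a FutureWarning. Migrating is a pure rename; a hypothetical subclass, for illustration:

from transformers import TFTrainer
import wandb


class MyTrainer(TFTrainer):  # hypothetical subclass, for illustration only
    # Old API (still works through the shim above, but warns):
    # def _setup_wandb(self):
    #     ...

    # New API: define the public name instead.
    def setup_wandb(self):
        wandb.init(project="my-project", config=vars(self.args))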
@@ -230,10 +242,19 @@ class TFTrainer:
         prediction_loss_only: Optional[bool] = None,
     ) -> PredictionOutput:
         """
-        Prediction/evaluation loop, shared by `evaluate()` and `predict()`.
+        Prediction/evaluation loop, shared by :func:`~transformers.TFTrainer.evaluate` and
+        :func:`~transformers.TFTrainer.predict`.

         Works both with or without labels.
         """
+        if hasattr(self, "_prediction_loop"):
+            warnings.warn(
+                "The `_prediction_loop` method is deprecated and won't be called in a future version, define `prediction_loop` in your subclass.",
+                FutureWarning,
+            )
+            return self._prediction_loop(dataset, steps, num_examples, description, prediction_loss_only=prediction_loss_only)
+
         prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only
@@ -250,7 +271,7 @@ class TFTrainer:
             self._past = None

         for step, batch in enumerate(dataset):
-            logits = self.distributed_test_steps(batch)
+            logits = self.distributed_prediction_steps(batch)
             _, labels = batch

             if not prediction_loss_only:
@@ -303,7 +324,13 @@ class TFTrainer:
         return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

-    def _log(self, logs: Dict[str, float]) -> None:
+    def log(self, logs: Dict[str, float]) -> None:
+        if hasattr(self, "_log"):
+            warnings.warn(
+                "The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
+                FutureWarning,
+            )
+            return self._log(logs)
         logs["epoch"] = self.epoch_logging

         if self.tb_writer:
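The same rename-plus-shim applies to logging. A subclass written against the new public name might look like this (hypothetical class, for illustration):

from typing import Dict

from transformers import TFTrainer


class PrintingTrainer(TFTrainer):  # hypothetical subclass, for illustration only
    def log(self, logs: Dict[str, float]) -> None:
        # Custom behavior first, then the stock handling (epoch bookkeeping
        # and the TensorBoard / W&B writers).
        print({k: round(v, 4) for k, v in logs.items()})
        super().log(logs)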
@@ -335,24 +362,28 @@ class TFTrainer:
         eval_ds, steps, num_examples = self.get_eval_tfdataset(eval_dataset)

         output = self._prediction_loop(eval_ds, steps, num_examples, description="Evaluation")

         logs = {**output.metrics}
         logs["epoch"] = self.epoch_logging
-        self._log(logs)
+        self.log(logs)

         return output.metrics

-    def test_step(self, features, labels):
-        per_example_loss, logits = self._run_model(features, labels, False)
+    def prediction_step(self, features: tf.Tensor, labels: tf.Tensor) -> tf.Tensor:
+        """
+        Compute the prediction on features and update the loss with labels.
+
+        Subclass and override to inject some custom behavior.
+        """
+        per_example_loss, logits = self.run_model(features, labels, False)
         self.eval_loss.update_state(per_example_loss)

         return logits

     @tf.function
-    def distributed_test_steps(self, batch):
-        logits = self.args.strategy.run(self.test_step, batch)
+    def distributed_prediction_steps(self, batch):
+        logits = self.args.strategy.run(self.prediction_step, batch)

         return logits
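With `test_step` now public as `prediction_step` (and documented as overridable), post-processing the logits is a one-method subclass. A sketch with a hypothetical class name:

import tensorflow as tf
from transformers import TFTrainer


class ClippedPredictionTrainer(TFTrainer):  # hypothetical subclass, for illustration only
    def prediction_step(self, features: tf.Tensor, labels: tf.Tensor) -> tf.Tensor:
        # Reuse the stock step (which also updates self.eval_loss), then
        # post-process the logits before they reach prediction_loop.
        logits = super().prediction_step(features, labels)
        return tf.clip_by_value(logits, -10.0, 10.0)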
@@ -446,7 +477,7 @@ class TFTrainer:
                     logs["loss"] = training_loss.numpy()
                     logs["epoch"] = self.epoch_logging

-                    self._log(logs)
+                    self.log(logs)

                 if self.global_step == 1 and self.args.debug:
                     with self.tb_writer.as_default():
@@ -469,7 +500,7 @@ class TFTrainer:
                     logs["learning_rate"] = self.lr_scheduler(self.global_step).numpy()
                     logs["epoch"] = self.epoch_logging

-                    self._log(logs)
+                    self.log(logs)

                 if self.global_step > 0 and self.global_step % self.args.save_steps == 0:
                     ckpt_save_path = self.model.ckpt_manager.save()
@@ -490,7 +521,12 @@ class TFTrainer:
             delattr(self, "_past")

     def training_step(self, features, labels):
-        per_example_loss, _ = self._run_model(features, labels, True)
+        """
+        Perform a training step on features and labels.
+
+        Subclass and override to inject some custom behavior.
+        """
+        per_example_loss, _ = self.run_model(features, labels, True)
         scaled_loss = per_example_loss / self.total_train_batch_size
         gradients = tf.gradients(scaled_loss, self.model.trainable_variables)
         gradients = [
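The scaling in `training_step` divides by `total_train_batch_size` (the global batch size) rather than the per-replica batch size, so that summing gradient contributions across replicas reproduces the gradient of the global mean loss. A quick standalone check of that arithmetic (toy numbers, not trainer code):

import tensorflow as tf

losses = tf.random.uniform([8])                # per-example losses, global batch of 8
replica_a, replica_b = losses[:4], losses[4:]  # pretend two replicas of 4 examples each
global_batch_size = 8

mean_loss = tf.reduce_mean(losses)
summed = tf.reduce_sum(replica_a / global_batch_size) + tf.reduce_sum(replica_b / global_batch_size)
tf.debugging.assert_near(mean_loss, summed)    # equal by construction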
@@ -534,14 +570,24 @@ class TFTrainer:
         with self.args.strategy.scope():
             self.args.strategy.run(self.apply_gradients, batch)

-    def _run_model(self, features, labels, training):
+    def run_model(self, features, labels, training):
         """
         Computes the loss of the given features and labels pair.
+
+        Subclass and override this method if you want to inject some custom behavior.
+
         Args:
             features: the batched features.
             labels: the batched labels.
             training: run the model in training mode or not
         """
+        if hasattr(self, "_run_model"):
+            warnings.warn(
+                "The `_run_model` method is deprecated and won't be called in a future version, define `run_model` in your subclass.",
+                FutureWarning,
+            )
+            return self._run_model(features, labels, training)
+
         if self.args.past_index >= 0 and getattr(self, "_past", None) is not None:
             features["mems"] = self._past
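`run_model` is the single place where the loss/logits pair is computed, so it is the natural hook for a custom loss. A sketch with a hypothetical subclass (the 0.01 penalty weight is arbitrary, for illustration only):

import tensorflow as tf
from transformers import TFTrainer


class PenalizedLossTrainer(TFTrainer):  # hypothetical subclass, for illustration only
    def run_model(self, features, labels, training):
        # Defer to the stock implementation, then add an L2-style penalty on
        # the logits; the scalar broadcasts over the per-example losses.
        per_example_loss, logits = super().run_model(features, labels, training)
        penalty = 0.01 * tf.reduce_mean(tf.square(logits))
        return per_example_loss + penalty, logits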
@@ -578,7 +624,7 @@ class TFTrainer:
         """
         test_ds, steps, num_examples = self.get_test_tfdataset(test_dataset)

-        return self._prediction_loop(test_ds, steps, num_examples, description="Prediction")
+        return self.prediction_loop(test_ds, steps, num_examples, description="Prediction")

     def save_model(self, output_dir: Optional[str] = None):
         """
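End to end, `predict` now routes through the public `prediction_loop`. Typical use, reusing the hypothetical `trainer` from the optimizer sketch above plus some labeled `test_dataset` of the same shape:

output = trainer.predict(test_dataset)  # PredictionOutput(predictions, label_ids, metrics)
print(output.metrics)                   # e.g. {"eval_loss": ...}
preds = output.predictions              # raw logits, one row per example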