chenpangpang / transformers

Commit fdccf82e (unverified)
Authored Sep 30, 2020 by Sylvain Gugger, committed by GitHub on Sep 30, 2020
Parent: cc4eff80

Remove config assumption in Trainer (#7464)

* Remove config assumption in Trainer
* Initialize for eval
Showing 2 changed files with 17 additions and 24 deletions

src/transformers/trainer.py        +16  -24
src/transformers/trainer_utils.py   +1   -0
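Every hunk below follows the same idea: `Trainer` should no longer assume that `self.model` is a `PreTrainedModel` with a `.config` attribute, so unconditional reads of `model.config` are either dropped or guarded by an `isinstance` check. A minimal sketch of that guard pattern outside the Trainer (illustration only; `describe_model` is a hypothetical helper, not part of the diff):

import torch
from transformers import BertConfig, BertModel, PreTrainedModel


def describe_model(model: torch.nn.Module) -> dict:
    # Hypothetical helper: only touch `model.config` when the model actually
    # carries one, i.e. when it is a PreTrainedModel; a plain nn.Module is fine.
    if isinstance(model, PreTrainedModel):
        return model.config.to_dict()
    return {}


plain = torch.nn.Linear(4, 2)
tiny_bert = BertModel(BertConfig(hidden_size=32, num_hidden_layers=1,
                                 num_attention_heads=2, intermediate_size=64))
print(describe_model(plain))            # {} -> no config to report
print(len(describe_model(tiny_bert)))   # full BertConfig as a dict

A plain `torch.nn.Module` passes through the guard untouched, which is the property the hunks below rely on.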
src/transformers/trainer.py
@@ -282,7 +282,7 @@ class Trainer:
         # Create output directory if needed
         if self.is_world_process_zero():
             os.makedirs(self.args.output_dir, exist_ok=True)
-        if is_torch_tpu_available():
+        if is_torch_tpu_available() and isinstance(self.model, PreTrainedModel):
            # Set an xla_device flag on the model's config.
            # We'll find a more elegant and not need to do this in the future.
            self.model.config.xla_device = True
@@ -490,11 +490,9 @@ class Trainer:
             logger.info(
                 'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
             )
-            try:
-                combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
-            except AttributeError:
-                # in case the model has no config
-                combined_dict = {**self.args.to_sanitized_dict()}
+            combined_dict = {**self.args.to_sanitized_dict()}
+            if isinstance(self.model, PreTrainedModel):
+                combined_dict = {**self.model.config.to_dict(), **combined_dict}
             wandb.init(
                 project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
             )
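Not part of the diff, just a reading aid: the try/except on a missing `config` attribute is replaced by seeding `combined_dict` from the training arguments and merging the model config only when one exists. The same merge, pulled out of the Trainer for illustration (`args_dict` stands in for `self.args.to_sanitized_dict()`):

from transformers import PreTrainedModel


def build_wandb_config(model, args_dict: dict) -> dict:
    # Start from the sanitized TrainingArguments and layer the model config on
    # top only for models that have one; on key clashes the args win, matching
    # the dict-merge order in the hunk above.
    combined_dict = {**args_dict}
    if isinstance(model, PreTrainedModel):
        combined_dict = {**model.config.to_dict(), **combined_dict}
    return combined_dict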
@@ -533,6 +531,7 @@ class Trainer:
             if experiment is not None:
                 experiment._set_model_graph(self.model, framework="transformers")
                 experiment._log_parameters(self.args, prefix="args/", framework="transformers")
-                experiment._log_parameters(self.model.config, prefix="config/", framework="transformers")
+                if isinstance(self.model, PreTrainedModel):
+                    experiment._log_parameters(self.model.config, prefix="config/", framework="transformers")
 
     def num_examples(self, dataloader: DataLoader) -> int:
@@ -679,7 +678,11 @@ class Trainer:
                 model,
                 device_ids=[self.args.local_rank],
                 output_device=self.args.local_rank,
-                find_unused_parameters=not getattr(model.config, "gradient_checkpointing", False),
+                find_unused_parameters=(
+                    not getattr(model.config, "gradient_checkpointing", False)
+                    if isinstance(model, PreTrainedModel)
+                    else True
+                ),
             )
             # find_unused_parameters breaks checkpointing as per
             # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
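The DDP hunk turns the `find_unused_parameters` argument into a conditional expression: gradient checkpointing can only be read off a config, so for anything that is not a `PreTrainedModel` the flag simply stays `True`. The condition pulled out on its own (a sketch, `model` being whatever was handed to the Trainer):

from transformers import PreTrainedModel


def ddp_find_unused_parameters(model) -> bool:
    # PreTrainedModel with gradient_checkpointing enabled -> False
    # PreTrainedModel without gradient checkpointing      -> True
    # any other nn.Module (no config to inspect)          -> True
    return (
        not getattr(model.config, "gradient_checkpointing", False)
        if isinstance(model, PreTrainedModel)
        else True
    )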
@@ -707,15 +710,14 @@ class Trainer:
         self.global_step = 0
         self.epoch = 0
         self.total_flos = 0
         epochs_trained = 0
         steps_trained_in_current_epoch = 0
 
         # Check if continuing training from a checkpoint
         if model_path is not None:
             # set global_step to global_step of last saved checkpoint from model path
             try:
                 self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
-                self.total_flos = getattr(self._actual_model(model).config, "total_flos", 0)
                 epochs_trained = self.global_step // num_update_steps_per_epoch
                 steps_trained_in_current_epoch = self.global_step % (num_update_steps_per_epoch)
@@ -723,14 +725,13 @@ class Trainer:
                 logger.info(" Continuing training from checkpoint, will skip to saved global_step")
                 logger.info(" Continuing training from epoch %d", epochs_trained)
                 logger.info(" Continuing training from global step %d", self.global_step)
-                logger.info(" Continuing training from %d non-embedding floating-point operations", self.total_flos)
                 logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
             except ValueError:
                 self.global_step = 0
-                self.total_flos = 0
                 logger.info(" Starting fine-tuning.")
 
         tr_loss = torch.tensor(0.0).to(self.args.device)
+        self.total_flos = self.state.total_flos
         logging_loss_scalar = 0.0
         model.zero_grad()
         disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
@@ -1029,7 +1030,7 @@ class Trainer:
             else:
                 total_flos = self.total_flos
             if total_flos > 0:
-                logs["total_flos"] = self.total_flos
+                logs["total_flos"] = total_flos
         if self.global_step is None:
             # when logging evaluation metrics without training
             self.global_step = 0
@@ -1245,11 +1246,9 @@ class Trainer:
         # Storing the number of floating-point operations that went into the model
         if self.total_flos is not None:
             if self.args.local_rank != -1:
-                total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
+                self.state.total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
             else:
-                total_flos = self.total_flos
-            if total_flos > 0:
-                self.model.config.total_flos = total_flos
+                self.state.total_flos = self.total_flos
 
     def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
         ordering_and_checkpoint_path = []
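Taken together, the total_flos hunks move the floating-point-operation counter off `model.config` and onto the trainer state: it is re-read from `self.state.total_flos` when the training loop starts and written back to the state when stored, so a model without a config can still accumulate it. A simplified, single-process sketch of the storing side (`_State` is a stand-in for the `TrainerState` extended in trainer_utils.py below):

from dataclasses import dataclass


@dataclass
class _State:
    # stand-in for TrainerState; the real class gains `total_flos: int = 0` below
    total_flos: int = 0


def store_flos_locally(state: _State, total_flos: int) -> None:
    # single-process version of the storing logic above: the counter is written
    # to the trainer state rather than to model.config
    if total_flos is not None:
        state.total_flos = total_flos


state = _State()
store_flos_locally(state, 3_200_000)
assert state.total_flos == 3_200_000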
@@ -1363,13 +1362,6 @@ class Trainer:
             prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
         )
 
-        assert not getattr(
-            self.model.config, "output_attentions", False
-        ), "The prediction loop does not work with `output_attentions=True`."
-        assert not getattr(
-            self.model.config, "output_hidden_states", False
-        ), "The prediction loop does not work with `output_hidden_states=True`."
-
         model = self.model
         # multi-gpu eval
         if self.args.n_gpu > 1:
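With the two asserts gone, the prediction loop never touches `self.model.config`, which is what lets a bare `nn.Module` go through evaluation; the limitation around `output_attentions=True` presumably still has to be respected by the model or the caller. A minimal config-less model that the evaluation loop can now accept might look like this (a sketch; the feature names `input_values`/`labels` are arbitrary, and the Trainer still expects the loss first in the returned tuple when labels are present):

import torch
from torch import nn


class TinyRegressor(nn.Module):
    # A plain nn.Module: no PretrainedConfig and no .config attribute.
    def __init__(self, in_features: int = 8):
        super().__init__()
        self.linear = nn.Linear(in_features, 1)

    def forward(self, input_values=None, labels=None):
        preds = self.linear(input_values).squeeze(-1)
        if labels is not None:
            loss = nn.functional.mse_loss(preds, labels)
            return loss, preds  # (loss, logits), the shape the Trainer unpacks
        return (preds,)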
src/transformers/trainer_utils.py
@@ -224,6 +224,7 @@ class TrainerState:
     A class containing the `Trainer` fields that will be saved along the model and optimizer.
     """
 
+    total_flos: int = 0
     best_metric: Optional[float] = None
     best_model_checkpoint: Optional[str] = None
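The only change on the utils side is the new `total_flos` field with a default of `0`, so a freshly constructed state already carries a counter even when no training step has run. Roughly:

# module path as shown in this diff; in other versions TrainerState may live elsewhere
from transformers.trainer_utils import TrainerState

state = TrainerState()
print(state.total_flos)             # 0 by default after this commit
print(state.best_metric)            # None
print(state.best_model_checkpoint)  # None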