chenpangpang / transformers

Commit fdccf82e (unverified)
Authored Sep 30, 2020 by Sylvain Gugger; committed by GitHub on Sep 30, 2020

Remove config assumption in Trainer (#7464)

* Remove config assumption in Trainer
* Initialize for eval
parent cc4eff80
Showing 2 changed files with 17 additions and 24 deletions:

  src/transformers/trainer.py        +16  -24
  src/transformers/trainer_utils.py   +1   -0
src/transformers/trainer.py
@@ -282,7 +282,7 @@ class Trainer:
         # Create output directory if needed
         if self.is_world_process_zero():
             os.makedirs(self.args.output_dir, exist_ok=True)
-        if is_torch_tpu_available():
+        if is_torch_tpu_available() and isinstance(self.model, PreTrainedModel):
             # Set an xla_device flag on the model's config.
             # We'll find a more elegant and not need to do this in the future.
             self.model.config.xla_device = True
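Why the new isinstance guard matters: Trainer can wrap any torch.nn.Module, and a plain module has no config attribute, so the previously unguarded self.model.config.xla_device = True would raise AttributeError on TPU. A minimal sketch (TinyRegressor is a hypothetical model, not part of the commit):

    import torch.nn as nn

    # A plain PyTorch module: there is no `config` attribute, so any
    # unguarded access to `model.config` raises AttributeError.
    class TinyRegressor(nn.Module):
        def __init__(self, n_features: int = 8):
            super().__init__()
            self.linear = nn.Linear(n_features, 1)

        def forward(self, x):
            return self.linear(x)

    model = TinyRegressor()
    print(hasattr(model, "config"))  # False -- hence the isinstance check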
@@ -490,11 +490,9 @@ class Trainer:
         logger.info(
             'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
         )
-        try:
-            combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
-        except AttributeError:
-            # in case the model has no config
-            combined_dict = {**self.args.to_sanitized_dict()}
+        combined_dict = {**self.args.to_sanitized_dict()}
+        if isinstance(self.model, PreTrainedModel):
+            combined_dict = {**self.model.config.to_dict(), **combined_dict}
         wandb.init(
             project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
         )
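A note on the design choice, as an aside not stated in the commit itself: except AttributeError is broader than it looks, because it also swallows an AttributeError raised inside config.to_dict() for an unrelated reason, whereas the isinstance check decides up front whether to read the config at all. A small sketch with hypothetical classes:

    class BrokenConfig:
        def to_dict(self):
            # An unrelated bug that happens to raise AttributeError.
            return self.missing_attribute

    class ModelWithBrokenConfig:
        config = BrokenConfig()

    model = ModelWithBrokenConfig()

    # Old pattern: the bug is silently misread as "model has no config".
    try:
        combined = {**model.config.to_dict()}
    except AttributeError:
        combined = {}
    print(combined)  # {} -- the bug is hidden

    # New pattern: isinstance(model, PreTrainedModel) gates the config read,
    # so a bug like the one above surfaces as a normal traceback.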
@@ -533,7 +531,8 @@ class Trainer:
         if experiment is not None:
             experiment._set_model_graph(self.model, framework="transformers")
             experiment._log_parameters(self.args, prefix="args/", framework="transformers")
-            experiment._log_parameters(self.model.config, prefix="config/", framework="transformers")
+            if isinstance(self.model, PreTrainedModel):
+                experiment._log_parameters(self.model.config, prefix="config/", framework="transformers")

     def num_examples(self, dataloader: DataLoader) -> int:
         """
@@ -679,7 +678,11 @@ class Trainer:
                 model,
                 device_ids=[self.args.local_rank],
                 output_device=self.args.local_rank,
-                find_unused_parameters=not getattr(model.config, "gradient_checkpointing", False),
+                find_unused_parameters=(
+                    not getattr(model.config, "gradient_checkpointing", False)
+                    if isinstance(model, PreTrainedModel)
+                    else True
+                ),
             )
             # find_unused_parameters breaks checkpointing as per
             # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
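The new else True branch deserves a gloss: a model that is not a PreTrainedModel has no config to consult for gradient_checkpointing, and True is the tolerant DDP setting for a module whose forward pass may not touch every parameter, at the cost of the checkpointing interaction noted in the diff's trailing comment. A sketch of the same decision as a standalone helper (the function name is hypothetical):

    import torch.nn as nn
    from transformers import PreTrainedModel

    def ddp_find_unused_parameters(model: nn.Module) -> bool:
        # Same conditional as the diff: consult model.config only when the
        # model is a PreTrainedModel; otherwise default to True.
        if isinstance(model, PreTrainedModel):
            return not getattr(model.config, "gradient_checkpointing", False)
        return True

    print(ddp_find_unused_parameters(nn.Linear(4, 2)))  # True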
@@ -707,15 +710,14 @@ class Trainer:
         self.global_step = 0
         self.epoch = 0
-        self.total_flos = 0
         epochs_trained = 0
         steps_trained_in_current_epoch = 0

         # Check if continuing training from a checkpoint
         if model_path is not None:
             # set global_step to global_step of last saved checkpoint from model path
             try:
                 self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
-                self.total_flos = getattr(self._actual_model(model).config, "total_flos", 0)
                 epochs_trained = self.global_step // num_update_steps_per_epoch
                 steps_trained_in_current_epoch = self.global_step % (num_update_steps_per_epoch)
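The double split in the surviving global_step line is easy to misread, so here is a worked example (the path is made up): split("-") isolates the step number at the end of a "checkpoint-<step>" directory name, and the extra split(os.path.sep)[0] strips anything after a trailing separator.

    import os

    model_path = os.path.join("output", "checkpoint-500") + os.path.sep
    step = int(model_path.split("-")[-1].split(os.path.sep)[0])
    print(step)  # 500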
@@ -723,14 +725,13 @@ class Trainer:
             logger.info("  Continuing training from checkpoint, will skip to saved global_step")
             logger.info("  Continuing training from epoch %d", epochs_trained)
             logger.info("  Continuing training from global step %d", self.global_step)
-            logger.info("  Continuing training from %d non-embedding floating-point operations", self.total_flos)
             logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
         except ValueError:
             self.global_step = 0
-            self.total_flos = 0
             logger.info("  Starting fine-tuning.")

         tr_loss = torch.tensor(0.0).to(self.args.device)
+        self.total_flos = self.state.total_flos
         logging_loss_scalar = 0.0
         model.zero_grad()
         disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
@@ -1029,7 +1030,7 @@ class Trainer:
             else:
                 total_flos = self.total_flos
             if total_flos > 0:
-                logs["total_flos"] = self.total_flos
+                logs["total_flos"] = total_flos
         if self.global_step is None:
             # when logging evaluation metrics without training
             self.global_step = 0
@@ -1245,11 +1246,9 @@ class Trainer:
         # Storing the number of floating-point operations that went into the model
         if self.total_flos is not None:
             if self.args.local_rank != -1:
-                total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
+                self.state.total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
             else:
-                total_flos = self.total_flos
-            if total_flos > 0:
-                self.model.config.total_flos = total_flos
+                self.state.total_flos = self.total_flos

     def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
         ordering_and_checkpoint_path = []
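In distributed runs each process only counts the flos of its own shard, so the value written to state.total_flos is the per-rank sum gathered by distributed_broadcast_scalars. A sketch of the equivalent aggregation with plain torch.distributed, assuming an initialized process group; this is not the library helper itself:

    import torch
    import torch.distributed as dist

    def aggregate_total_flos(local_flos: int) -> int:
        # Every rank contributes its local count; summing across ranks gives
        # the job-wide total, mirroring
        # distributed_broadcast_scalars([self.total_flos]).sum().item().
        if dist.is_available() and dist.is_initialized():
            t = torch.tensor(local_flos, dtype=torch.long)
            dist.all_reduce(t, op=dist.ReduceOp.SUM)
            return int(t.item())
        return local_flos  # the local_rank == -1 branch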
@@ -1363,13 +1362,6 @@ class Trainer:
             prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
         )

-        assert not getattr(
-            self.model.config, "output_attentions", False
-        ), "The prediction loop does not work with `output_attentions=True`."
-        assert not getattr(
-            self.model.config, "output_hidden_states", False
-        ), "The prediction loop does not work with `output_hidden_states=True`."
-
         model = self.model
         # multi-gpu eval
         if self.args.n_gpu > 1:
src/transformers/trainer_utils.py
@@ -224,6 +224,7 @@ class TrainerState:
     A class containing the `Trainer` fields that will be saved along the model and optimizer.
     """

+    total_flos: int = 0
     best_metric: Optional[float] = None
     best_model_checkpoint: Optional[str] = None
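This one-line addition is what the trainer.py changes lean on: total_flos now lives on the dataclass that is serialized next to the model and optimizer, instead of being smuggled through model.config. A minimal sketch of the round trip (TrainerStateSketch is a stand-in, not the real class):

    import json
    from dataclasses import asdict, dataclass
    from typing import Optional

    @dataclass
    class TrainerStateSketch:
        total_flos: int = 0
        best_metric: Optional[float] = None
        best_model_checkpoint: Optional[str] = None

    state = TrainerStateSketch(total_flos=123_456)
    payload = json.dumps(asdict(state))            # saved with the checkpoint
    restored = TrainerStateSketch(**json.loads(payload))
    assert restored.total_flos == 123_456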