chenpangpang / transformers / Commits / a7da2996

Unverified commit a7da2996, authored Aug 10, 2023 by Zach Mueller, committed by GitHub on Aug 10, 2023.

Fix issue with ratio evaluation steps and auto find batch size (#25436)

* Fully rebased solution
* 500
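Context for the change: `logging_steps`, `eval_steps`, and `save_steps` may be given as a float in `(0, 1)`, interpreted as a ratio of the total training steps. Before this commit, `Trainer._inner_training_loop` resolved those ratios by overwriting the values on `args`; when `auto_find_batch_size` catches an OOM, shrinks the batch size, and re-enters the inner loop, `max_steps` changes but the ratio has already been destroyed, so a stale absolute interval survives. The commit resolves the ratios onto `self.state` instead, leaving `args` untouched for the next attempt. A minimal sketch of the failure mode (names and numbers are illustrative, not from the commit):

```python
import math

def resolve_in_place(steps, max_steps):
    # Pre-commit behaviour: the shared config value is mutated in place.
    if steps and steps < 1:
        steps = math.ceil(max_steps * steps)
    return steps

logging_steps = 0.1  # user asks to log every 10% of training
logging_steps = resolve_in_place(logging_steps, max_steps=1000)  # -> 100

# auto_find_batch_size hits an OOM, halves the batch size, and restarts
# the inner loop with max_steps=2000. The ratio is gone: 100 < 1 is False,
# so the stale interval 100 survives instead of the expected 200.
logging_steps = resolve_in_place(logging_steps, max_steps=2000)  # still 100
```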
parent 2d6839ea

Changes: 3 changed files with 31 additions and 13 deletions (+31 −13)

* src/transformers/integrations.py (+1 −1)
* src/transformers/trainer.py (+17 −8)
* src/transformers/trainer_callback.py (+13 −4)
src/transformers/integrations.py

```diff
@@ -746,7 +746,7 @@ class WandbCallback(TrainerCallback):
         # keep track of model topology and gradients, unsupported on TPU
         _watch_model = os.getenv("WANDB_WATCH", "false")
         if not is_torch_tpu_available() and _watch_model in ("all", "parameters", "gradients"):
-            self._wandb.watch(model, log=_watch_model, log_freq=max(100, args.logging_steps))
+            self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps))

     def on_train_begin(self, args, state, control, model=None, **kwargs):
         if self._wandb is None:
```
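The W&B hook follows the move: `state.logging_steps` now holds the resolved absolute interval, while `args.logging_steps` may still carry the raw ratio. Feeding a ratio into `max(100, ...)` would silently pin the watch frequency to the floor, as this illustrative comparison shows:

```python
max(100, 0.1)  # a raw ratio always loses to the floor of 100
max(100, 250)  # the resolved interval (e.g. a 0.1 ratio over 2500 steps) wins
```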
src/transformers/trainer.py

```diff
@@ -1586,14 +1586,6 @@ class Trainer:
                     f" {args.max_steps}"
                 )

-        # Compute absolute values for logging, eval, and save if given as ratio
-        if args.logging_steps and args.logging_steps < 1:
-            args.logging_steps = math.ceil(max_steps * args.logging_steps)
-        if args.eval_steps and args.eval_steps < 1:
-            args.eval_steps = math.ceil(max_steps * args.eval_steps)
-        if args.save_steps and args.save_steps < 1:
-            args.save_steps = math.ceil(max_steps * args.save_steps)
-
         if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
             if self.args.n_gpu > 1:
                 # nn.DataParallel(model) replicates the model, creating new variables and module
@@ -1627,6 +1619,23 @@ class Trainer:
         self.state = TrainerState()
         self.state.is_hyper_param_search = trial is not None

+        # Compute absolute values for logging, eval, and save if given as ratio
+        if args.logging_steps is not None:
+            if args.logging_steps < 1:
+                self.state.logging_steps = math.ceil(max_steps * args.logging_steps)
+            else:
+                self.state.logging_steps = args.logging_steps
+        if args.eval_steps is not None:
+            if args.eval_steps < 1:
+                self.state.eval_steps = math.ceil(max_steps * args.eval_steps)
+            else:
+                self.state.eval_steps = args.eval_steps
+        if args.save_steps is not None:
+            if args.save_steps < 1:
+                self.state.save_steps = math.ceil(max_steps * args.save_steps)
+            else:
+                self.state.save_steps = args.save_steps
+
         # Activate gradient checkpointing if needed
         if args.gradient_checkpointing:
             self.model.gradient_checkpointing_enable()
```
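The three added blocks share one pattern. A standalone sketch of that conversion (the helper name `resolve_interval` is ours, not part of the commit):

```python
import math
from typing import Optional, Union

def resolve_interval(value: Optional[Union[int, float]], max_steps: int) -> Optional[int]:
    """Resolve a step interval to an absolute count.

    Values in (0, 1) are treated as a ratio of max_steps; values >= 1
    are assumed to already be absolute step counts.
    """
    if value is None:
        return None
    if value < 1:
        return math.ceil(max_steps * value)
    return int(value)

assert resolve_interval(0.1, 1000) == 100    # ratio of total steps
assert resolve_interval(0.1, 2000) == 200    # recomputed cleanly on a retry
assert resolve_interval(500, 2000) == 500    # absolute values pass through
assert resolve_interval(None, 2000) is None  # unset intervals stay unset
```

Because this runs at the top of every `_inner_training_loop` call and writes to a freshly created `TrainerState`, each `auto_find_batch_size` retry re-derives the intervals from the untouched ratios in `args`.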
src/transformers/trainer_callback.py

```diff
@@ -53,6 +53,12 @@ class TrainerState:
             During training, represents the number of update steps completed.
         max_steps (`int`, *optional*, defaults to 0):
             The number of update steps to do during the current training.
+        logging_steps (`int`, *optional*, defaults to 500):
+            Log every X updates steps
+        eval_steps (`int`, *optional*):
+            Run an evaluation every X steps.
+        save_steps (`int`, *optional*, defaults to 500):
+            Save checkpoint every X updates steps.
         total_flos (`float`, *optional*, defaults to 0):
             The total number of floating operations done by the model since the beginning of training (stored as floats
             to avoid overflow).
@@ -77,6 +83,9 @@ class TrainerState:
     epoch: Optional[float] = None
     global_step: int = 0
     max_steps: int = 0
+    logging_steps: int = 500
+    eval_steps: int = 500
+    save_steps: int = 500
     num_train_epochs: int = 0
     total_flos: float = 0
     log_history: List[Dict[str, float]] = None
```
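Because the resolved intervals now live on `TrainerState`, they serialize with the rest of the training state. A small sketch of the round trip (the file name is illustrative; `save_to_json` and `load_from_json` are existing `TrainerState` methods):

```python
from transformers import TrainerState

state = TrainerState()
# Defaults mirror the dataclass fields added above.
assert (state.logging_steps, state.eval_steps, state.save_steps) == (500, 500, 500)

# TrainerState is what Trainer writes to checkpoint-*/trainer_state.json,
# so resolved absolute intervals travel with a checkpoint.
state.logging_steps = 50  # e.g. a 0.1 ratio resolved against 500 max_steps
state.save_to_json("trainer_state.json")
restored = TrainerState.load_from_json("trainer_state.json")
assert restored.logging_steps == 50
```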
```diff
@@ -421,13 +430,13 @@ class DefaultFlowCallback(TrainerCallback):
         # Log
         if state.global_step == 1 and args.logging_first_step:
             control.should_log = True
-        if args.logging_strategy == IntervalStrategy.STEPS and state.global_step % args.logging_steps == 0:
+        if args.logging_strategy == IntervalStrategy.STEPS and state.global_step % state.logging_steps == 0:
             control.should_log = True

         # Evaluate
         if (
             args.evaluation_strategy == IntervalStrategy.STEPS
-            and state.global_step % args.eval_steps == 0
+            and state.global_step % state.eval_steps == 0
             and args.eval_delay <= state.global_step
         ):
             control.should_evaluate = True
@@ -435,8 +444,8 @@ class DefaultFlowCallback(TrainerCallback):
         # Save
         if (
             args.save_strategy == IntervalStrategy.STEPS
-            and args.save_steps > 0
-            and state.global_step % args.save_steps == 0
+            and state.save_steps > 0
+            and state.global_step % state.save_steps == 0
         ):
             control.should_save = True

```
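End to end, the change lets ratio-valued intervals compose with batch-size auto-scaling. A hedged usage sketch against the 2023-era `TrainingArguments` API (the output path is a placeholder):

```python
from transformers import TrainingArguments

# Floats in (0, 1) are interpreted as a ratio of total training steps.
# With this commit they are resolved onto trainer.state at the start of
# each inner training loop, so a restart triggered by auto_find_batch_size
# (which changes max_steps) re-derives them from the untouched args.
args = TrainingArguments(
    output_dir="out",
    auto_find_batch_size=True,   # requires accelerate; shrinks batch on OOM
    logging_strategy="steps",
    logging_steps=0.05,          # log every 5% of training
    evaluation_strategy="steps",
    eval_steps=0.1,              # evaluate every 10% of training
    save_strategy="steps",
    save_steps=0.25,             # checkpoint every 25% of training
)
```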