Project: chenpangpang/transformers

Commit a59bcefb (unverified)
Authored Aug 31, 2020 by Sylvain Gugger; committed by GitHub on Aug 31, 2020
Parent: 23f9611c

Split hp search methods (#6857)

* Split the run_hp_search by backend
* Unused import
Changes: 2 files, showing 83 additions and 73 deletions (+83 -73)

  src/transformers/integrations.py  +79 -71
  src/transformers/trainer.py        +4 -2
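
The public entry point is not changed by this commit: hyperparameter search is still launched through Trainer.hyperparameter_search, which now just picks the backend-specific helper shown in the diffs below. A minimal usage sketch for the Optuna backend (not part of this commit; the search-space names and ranges are illustrative, and the Trainer passed in is assumed to have been built with a model_init so it can re-instantiate the model for each trial):

from transformers import Trainer
from transformers.trainer_utils import BestRun


def search_with_optuna(trainer: Trainer) -> BestRun:
    """Run an Optuna-backed search on an already-built Trainer (assumes `model_init` is set and optuna is installed)."""

    def hp_space(trial):
        # Illustrative Optuna search space; parameter names and ranges are assumptions.
        return {
            "learning_rate": trial.suggest_float("learning_rate", 1e-5, 5e-4, log=True),
            "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 3),
        }

    return trainer.hyperparameter_search(
        hp_space=hp_space,
        backend="optuna",      # dispatches to run_hp_search_optuna after this commit
        direction="minimize",  # the default objective is the evaluation loss
        n_trials=10,
        timeout=3600,          # popped by run_hp_search_optuna and passed to study.optimize
    )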
src/transformers/integrations.py  (view file @ a59bcefb)

@@ -3,7 +3,7 @@ import os
 import numpy as np
 
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, HPSearchBackend
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun
 from transformers.utils import logging
 
@@ -83,7 +83,7 @@ def default_hp_search_backend():
         return "ray"
 
 
-def run_hp_search(trainer, n_trials, direction, kwargs):
+def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
     def _objective(trial, checkpoint_dir=None):
         model_path = None
         if checkpoint_dir:
@@ -96,80 +96,88 @@ def run_hp_search(trainer, n_trials, direction, kwargs):
         if getattr(trainer, "objective", None) is None:
             metrics = trainer.evaluate()
             trainer.objective = trainer.compute_objective(metrics)
-            if trainer.hp_search_backend == HPSearchBackend.RAY:
-                trainer._tune_save_checkpoint()
-                ray.tune.report(objective=trainer.objective)
         return trainer.objective
 
-    if trainer.hp_search_backend == HPSearchBackend.OPTUNA:
-        timeout = kwargs.pop("timeout", None)
-        n_jobs = kwargs.pop("n_jobs", 1)
-        study = optuna.create_study(direction=direction, **kwargs)
-        study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs)
-        best_trial = study.best_trial
-        best_run = BestRun(str(best_trial.number), best_trial.value, best_trial.params)
-    elif trainer.hp_search_backend == HPSearchBackend.RAY:
-        # The model and TensorBoard writer do not pickle so we have to remove them (if they exists)
-        # while doing the ray hp search.
-        _tb_writer = trainer.tb_writer
-        trainer.tb_writer = None
-        trainer.model = None
-        # Setup default `resources_per_trial` and `reporter`.
-        if "resources_per_trial" not in kwargs and trainer.args.n_gpu > 0:
-            # `args.n_gpu` is considered the total number of GPUs that will be split
-            # among the `n_jobs`
-            n_jobs = int(kwargs.pop("n_jobs", 1))
-            num_gpus_per_trial = trainer.args.n_gpu
-            if num_gpus_per_trial / n_jobs >= 1:
-                num_gpus_per_trial = int(np.ceil(num_gpus_per_trial / n_jobs))
-            kwargs["resources_per_trial"] = {"gpu": num_gpus_per_trial}
-
-        if "reporter" not in kwargs:
-            from ray.tune import CLIReporter
-
-            kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])
-        if "keep_checkpoints_num" in kwargs and kwargs["keep_checkpoints_num"] > 0:
-            # `keep_checkpoints_num=0` would disabled checkpointing
-            trainer.use_tune_checkpoints = True
-            if kwargs["keep_checkpoints_num"] > 1:
-                logger.warning(
-                    "Currently keeping {} checkpoints for each trial. Checkpoints are usually huge, "
-                    "consider setting `keep_checkpoints_num=1`."
-                )
-        if "scheduler" in kwargs:
-            from ray.tune.schedulers import (
-                ASHAScheduler,
-                HyperBandForBOHB,
-                MedianStoppingRule,
-                PopulationBasedTraining,
-            )
-
-            # Check if checkpointing is enabled for PopulationBasedTraining
-            if isinstance(kwargs["scheduler"], PopulationBasedTraining):
-                if not trainer.use_tune_checkpoints:
-                    logger.warning(
-                        "You are using PopulationBasedTraining but you haven't enabled checkpointing. "
-                        "This means your trials will train from scratch everytime they are exploiting "
-                        "new configurations. Consider enabling checkpointing by passing "
-                        "`keep_checkpoints_num=1` as an additional argument to `Trainer.hyperparameter_search`."
-                    )
-
-            # Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting.
-            if isinstance(
-                kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
-            ) and (not trainer.args.do_eval or not trainer.args.evaluate_during_training):
-                raise RuntimeError(
-                    "You are using {cls} as a scheduler but you haven't enabled evaluation during training. "
-                    "This means your trials will not report intermediate results to Ray Tune, and "
-                    "can thus not be stopped early or used to exploit other trials parameters. "
-                    "If this is what you want, do not use {cls}. If you would like to use {cls}, "
-                    "make sure you pass `do_eval=True` and `evaluate_during_training=True` in the "
-                    "Trainer `args`.".format(cls=type(kwargs["scheduler"]).__name__)
-                )
-
-        analysis = ray.tune.run(_objective, config=trainer.hp_space(None), num_samples=n_trials, **kwargs)
-        best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3])
-        best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)
-        trainer.tb_writer = _tb_writer
+    timeout = kwargs.pop("timeout", None)
+    n_jobs = kwargs.pop("n_jobs", 1)
+    study = optuna.create_study(direction=direction, **kwargs)
+    study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs)
+    best_trial = study.best_trial
+    return BestRun(str(best_trial.number), best_trial.value, best_trial.params)
+
+
+def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
+    def _objective(trial, checkpoint_dir=None):
+        model_path = None
+        if checkpoint_dir:
+            for subdir in os.listdir(checkpoint_dir):
+                if subdir.startswith(PREFIX_CHECKPOINT_DIR):
+                    model_path = os.path.join(checkpoint_dir, subdir)
+        trainer.objective = None
+        trainer.train(model_path=model_path, trial=trial)
+        # If there hasn't been any evaluation during the training loop.
+        if getattr(trainer, "objective", None) is None:
+            metrics = trainer.evaluate()
+            trainer.objective = trainer.compute_objective(metrics)
+            trainer._tune_save_checkpoint()
+            ray.tune.report(objective=trainer.objective)
+        return trainer.objective
+
+    # The model and TensorBoard writer do not pickle so we have to remove them (if they exists)
+    # while doing the ray hp search.
+    _tb_writer = trainer.tb_writer
+    trainer.tb_writer = None
+    trainer.model = None
+    # Setup default `resources_per_trial` and `reporter`.
+    if "resources_per_trial" not in kwargs and trainer.args.n_gpu > 0:
+        # `args.n_gpu` is considered the total number of GPUs that will be split
+        # among the `n_jobs`
+        n_jobs = int(kwargs.pop("n_jobs", 1))
+        num_gpus_per_trial = trainer.args.n_gpu
+        if num_gpus_per_trial / n_jobs >= 1:
+            num_gpus_per_trial = int(np.ceil(num_gpus_per_trial / n_jobs))
+        kwargs["resources_per_trial"] = {"gpu": num_gpus_per_trial}
+
+    if "reporter" not in kwargs:
+        from ray.tune import CLIReporter
+
+        kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])
+    if "keep_checkpoints_num" in kwargs and kwargs["keep_checkpoints_num"] > 0:
+        # `keep_checkpoints_num=0` would disabled checkpointing
+        trainer.use_tune_checkpoints = True
+        if kwargs["keep_checkpoints_num"] > 1:
+            logger.warning(
+                "Currently keeping {} checkpoints for each trial. Checkpoints are usually huge, "
+                "consider setting `keep_checkpoints_num=1`."
+            )
+    if "scheduler" in kwargs:
+        from ray.tune.schedulers import ASHAScheduler, HyperBandForBOHB, MedianStoppingRule, PopulationBasedTraining
+
+        # Check if checkpointing is enabled for PopulationBasedTraining
+        if isinstance(kwargs["scheduler"], PopulationBasedTraining):
+            if not trainer.use_tune_checkpoints:
+                logger.warning(
+                    "You are using PopulationBasedTraining but you haven't enabled checkpointing. "
+                    "This means your trials will train from scratch everytime they are exploiting "
+                    "new configurations. Consider enabling checkpointing by passing "
+                    "`keep_checkpoints_num=1` as an additional argument to `Trainer.hyperparameter_search`."
+                )
+
+        # Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting.
+        if isinstance(
+            kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
+        ) and (not trainer.args.do_eval or not trainer.args.evaluate_during_training):
+            raise RuntimeError(
+                "You are using {cls} as a scheduler but you haven't enabled evaluation during training. "
+                "This means your trials will not report intermediate results to Ray Tune, and "
+                "can thus not be stopped early or used to exploit other trials parameters. "
+                "If this is what you want, do not use {cls}. If you would like to use {cls}, "
+                "make sure you pass `do_eval=True` and `evaluate_during_training=True` in the "
+                "Trainer `args`.".format(cls=type(kwargs["scheduler"]).__name__)
            )
+
+    analysis = ray.tune.run(_objective, config=trainer.hp_space(None), num_samples=n_trials, **kwargs)
+    best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3])
+    best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)
+    trainer.tb_writer = _tb_writer
     return best_run
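
For the Ray backend, the keyword arguments inspected above (scheduler, keep_checkpoints_num, resources_per_trial, n_jobs, reporter) travel from Trainer.hyperparameter_search into run_hp_search_ray and then on to ray.tune.run. A hedged sketch of passing them through (illustrative values; the Trainer is again assumed to have model_init set, no hp_space is passed so the default Ray search space is used, and do_eval/evaluate_during_training must be enabled so the ASHA scheduler can see intermediate reports, otherwise the RuntimeError above is raised):

from ray.tune.schedulers import ASHAScheduler

from transformers import Trainer
from transformers.trainer_utils import BestRun


def search_with_ray(trainer: Trainer) -> BestRun:
    """Ray Tune-backed search; extra kwargs reach ray.tune.run via run_hp_search_ray."""
    return trainer.hyperparameter_search(
        backend="ray",
        direction="maximize",
        n_trials=8,                                 # becomes num_samples for ray.tune.run
        # Early-stopping scheduler; trials report under the "objective" metric.
        scheduler=ASHAScheduler(metric="objective", mode="max"),
        keep_checkpoints_num=1,                     # > 0 turns on trainer.use_tune_checkpoints
        resources_per_trial={"cpu": 2, "gpu": 1},   # skips the automatic n_gpu / n_jobs split
    )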
src/transformers/trainer.py  (view file @ a59bcefb)

@@ -27,7 +27,8 @@ from .integrations import (
     is_ray_available,
     is_tensorboard_available,
     is_wandb_available,
-    run_hp_search,
+    run_hp_search_optuna,
+    run_hp_search_ray,
 )
 from .modeling_utils import PreTrainedModel
 from .optimization import AdamW, get_linear_schedule_with_warmup
@@ -884,7 +885,8 @@ class Trainer:
         self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
         self.compute_objective = default_compute_objective if compute_objective is None else compute_objective
 
-        best_run = run_hp_search(self, n_trials, direction, kwargs)
+        run_hp_search = run_hp_search_optuna if backend == HPSearchBackend.OPTUNA else run_hp_search_ray
+        best_run = run_hp_search(self, n_trials, direction, **kwargs)
 
         self.hp_search_backend = None
         return best_run
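
Whichever helper runs, the caller gets back a BestRun from transformers.trainer_utils, so the return shape of Trainer.hyperparameter_search is the same for both backends; only how the fields are filled differs (Optuna uses the trial number, value, and params, while Ray uses the trial id, the last reported objective, and the trial config). A small illustrative sketch with made-up values, assuming BestRun keeps its (run_id, objective, hyperparameters) fields:

from transformers.trainer_utils import BestRun

# Made-up examples of the two result shapes (values are not from a real run).
optuna_style = BestRun("7", 0.312, {"learning_rate": 3e-5, "num_train_epochs": 2})
ray_style = BestRun("_objective_00012", 0.874, {"learning_rate": 1e-4})

for run in (optuna_style, ray_style):
    print(run.run_id, run.objective, run.hyperparameters)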