chenpangpang / transformers / Commits / d63ab615

Unverified commit d63ab615, authored Jan 25, 2021 by Kai Fricke, committed by GitHub on Jan 25, 2021

Use object store to pass trainer object to Ray Tune (#9749)

parent 6312fed4
Changes: 1 changed file with 14 additions and 9 deletions

src/transformers/integrations.py (+14, -9)
```diff
@@ -149,20 +149,20 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
 def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
     import ray
 
-    def _objective(trial, checkpoint_dir=None):
+    def _objective(trial, local_trainer, checkpoint_dir=None):
         model_path = None
         if checkpoint_dir:
             for subdir in os.listdir(checkpoint_dir):
                 if subdir.startswith(PREFIX_CHECKPOINT_DIR):
                     model_path = os.path.join(checkpoint_dir, subdir)
-        trainer.objective = None
-        trainer.train(model_path=model_path, trial=trial)
+        local_trainer.objective = None
+        local_trainer.train(model_path=model_path, trial=trial)
         # If there hasn't been any evaluation during the training loop.
-        if getattr(trainer, "objective", None) is None:
-            metrics = trainer.evaluate()
-            trainer.objective = trainer.compute_objective(metrics)
-            trainer._tune_save_checkpoint()
-            ray.tune.report(objective=trainer.objective, **metrics, done=True)
+        if getattr(local_trainer, "objective", None) is None:
+            metrics = local_trainer.evaluate()
+            local_trainer.objective = local_trainer.compute_objective(metrics)
+            local_trainer._tune_save_checkpoint()
+            ray.tune.report(objective=local_trainer.objective, **metrics, done=True)
 
     # The model and TensorBoard writer do not pickle so we have to remove them (if they exists)
     # while doing the ray hp search.
```
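The hunk above changes the trainable's signature: `_objective` no longer closes over the outer `trainer`, but receives it explicitly as `local_trainer`. The second hunk (below) supplies that argument via `ray.tune.with_parameters`, which ships the object through Ray's object store instead of pickling it into the function for every trial. Here is a minimal, self-contained sketch of the same pattern; the names `my_trainable` and `payload` are hypothetical and not part of the commit, and the `tune.report(**kwargs)` call style matches the Ray Tune API current at the time of this commit (Ray ~1.1), which later versions have since changed:

```python
from ray import tune

def my_trainable(config, payload=None):
    # `payload` is injected by tune.with_parameters below. Tune stores it
    # once in the Ray object store and passes each trial a reference,
    # rather than serializing it into the trainable's closure per trial.
    tune.report(objective=config["x"] * len(payload))

payload = list(range(10_000))  # stand-in for a large object such as a Trainer
analysis = tune.run(
    tune.with_parameters(my_trainable, payload=payload),
    config={"x": tune.grid_search([1, 2, 3])},
)
print(analysis.get_best_trial(metric="objective", mode="max").config)
```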
```diff
@@ -217,7 +217,12 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
                 "Trainer `args`.".format(cls=type(kwargs["scheduler"]).__name__)
             )
 
-    analysis = ray.tune.run(_objective, config=trainer.hp_space(None), num_samples=n_trials, **kwargs)
+    analysis = ray.tune.run(
+        ray.tune.with_parameters(_objective, local_trainer=trainer),
+        config=trainer.hp_space(None),
+        num_samples=n_trials,
+        **kwargs,
+    )
     best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3])
     best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)
     if _tb_writer is not None:
```
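Under the hood, `tune.with_parameters` amounts to the `ray.put` pattern: the large object is placed in the object store once, and each remote trial receives a handle that Ray resolves on the worker. A rough sketch of that underlying mechanism, again with hypothetical names (`big`, `trial`); this illustrates the object-store behavior and is not code from this commit:

```python
import ray

ray.init(num_cpus=2)

big = list(range(1_000_000))  # stand-in for the Trainer object
ref = ray.put(big)            # serialized once into the shared object store

@ray.remote
def trial(x, obj):
    # Ray resolves the ObjectRef passed below before calling this function,
    # so `obj` arrives via the object store rather than per-task pickling.
    return x * len(obj)

print(ray.get([trial.remote(x, ref) for x in range(3)]))
```

This matters here because a `Trainer` is large and can hold members that pickle poorly (note the comment about the model and TensorBoard writer retained in the first hunk), so routing it through the object store avoids re-serializing it for every trial.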