Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
98569d4b
"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "be9438ed43dc4027d3c59af8405e3996ad3d9825"
Unverified
Commit
98569d4b
authored
Feb 26, 2021
by
Kai Fricke
Committed by
GitHub
Feb 26, 2021
Browse files
Add Ray Tune hyperparameter search integration test (#10414)
parent
d03695f3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
56 additions
and
2 deletions
+56
-2
tests/test_trainer.py
tests/test_trainer.py
+56
-2
No files found.
tests/test_trainer.py
View file @
98569d4b
...
@@ -27,6 +27,7 @@ from transformers.testing_utils import (
...
@@ -27,6 +27,7 @@ from transformers.testing_utils import (
get_tests_dir
,
get_tests_dir
,
require_datasets
,
require_datasets
,
require_optuna
,
require_optuna
,
require_ray
,
require_sentencepiece
,
require_sentencepiece
,
require_tokenizers
,
require_tokenizers
,
require_torch
,
require_torch
,
...
@@ -80,6 +81,12 @@ class RegressionDataset:
...
@@ -80,6 +81,12 @@ class RegressionDataset:
return
result
return
result
@
dataclasses
.
dataclass
class
RegressionTrainingArguments
(
TrainingArguments
):
a
:
float
=
0.0
b
:
float
=
0.0
class
RepeatDataset
:
class
RepeatDataset
:
def
__init__
(
self
,
x
,
length
=
64
):
def
__init__
(
self
,
x
,
length
=
64
):
self
.
x
=
x
self
.
x
=
x
...
@@ -200,7 +207,8 @@ if is_torch_available():
...
@@ -200,7 +207,8 @@ if is_torch_available():
optimizers
=
kwargs
.
pop
(
"optimizers"
,
(
None
,
None
))
optimizers
=
kwargs
.
pop
(
"optimizers"
,
(
None
,
None
))
output_dir
=
kwargs
.
pop
(
"output_dir"
,
"./regression"
)
output_dir
=
kwargs
.
pop
(
"output_dir"
,
"./regression"
)
model_init
=
kwargs
.
pop
(
"model_init"
,
None
)
model_init
=
kwargs
.
pop
(
"model_init"
,
None
)
args
=
TrainingArguments
(
output_dir
,
**
kwargs
)
args
=
RegressionTrainingArguments
(
output_dir
,
a
=
a
,
b
=
b
,
**
kwargs
)
return
Trainer
(
return
Trainer
(
model
,
model
,
args
,
args
,
...
@@ -973,7 +981,7 @@ class TrainerIntegrationTest(unittest.TestCase):
...
@@ -973,7 +981,7 @@ class TrainerIntegrationTest(unittest.TestCase):
@
require_torch
@
require_torch
@
require_optuna
@
require_optuna
class
TrainerHyperParameterIntegrationTest
(
unittest
.
TestCase
):
class
TrainerHyperParameter
Optuna
IntegrationTest
(
unittest
.
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
args
=
TrainingArguments
(
"."
)
args
=
TrainingArguments
(
"."
)
self
.
n_epochs
=
args
.
num_train_epochs
self
.
n_epochs
=
args
.
num_train_epochs
...
@@ -1014,3 +1022,49 @@ class TrainerHyperParameterIntegrationTest(unittest.TestCase):
...
@@ -1014,3 +1022,49 @@ class TrainerHyperParameterIntegrationTest(unittest.TestCase):
model_init
=
model_init
,
model_init
=
model_init
,
)
)
trainer
.
hyperparameter_search
(
direction
=
"minimize"
,
hp_space
=
hp_space
,
hp_name
=
hp_name
,
n_trials
=
4
)
trainer
.
hyperparameter_search
(
direction
=
"minimize"
,
hp_space
=
hp_space
,
hp_name
=
hp_name
,
n_trials
=
4
)
@
require_torch
@
require_ray
class
TrainerHyperParameterRayIntegrationTest
(
unittest
.
TestCase
):
def
setUp
(
self
):
args
=
TrainingArguments
(
"."
)
self
.
n_epochs
=
args
.
num_train_epochs
self
.
batch_size
=
args
.
train_batch_size
def
test_hyperparameter_search
(
self
):
class
MyTrialShortNamer
(
TrialShortNamer
):
DEFAULTS
=
{
"a"
:
0
,
"b"
:
0
}
def
hp_space
(
trial
):
from
ray
import
tune
return
{
"a"
:
tune
.
randint
(
-
4
,
4
),
"b"
:
tune
.
randint
(
-
4
,
4
),
}
def
model_init
(
config
):
model_config
=
RegressionModelConfig
(
a
=
config
[
"a"
],
b
=
config
[
"b"
],
double_output
=
False
)
return
RegressionPreTrainedModel
(
model_config
)
def
hp_name
(
params
):
return
MyTrialShortNamer
.
shortname
(
params
)
with
tempfile
.
TemporaryDirectory
()
as
tmp_dir
:
trainer
=
get_regression_trainer
(
output_dir
=
tmp_dir
,
learning_rate
=
0.1
,
logging_steps
=
1
,
evaluation_strategy
=
EvaluationStrategy
.
EPOCH
,
num_train_epochs
=
4
,
disable_tqdm
=
True
,
load_best_model_at_end
=
True
,
logging_dir
=
"runs"
,
run_name
=
"test"
,
model_init
=
model_init
,
)
trainer
.
hyperparameter_search
(
direction
=
"minimize"
,
hp_space
=
hp_space
,
hp_name
=
hp_name
,
backend
=
"ray"
,
n_trials
=
4
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment