ModelZoo / ResNet50_tensorflow · Commits · d57ba596

Commit d57ba596, authored Dec 18, 2020 by Le Hou; committed by A. Unique TensorFlower, Dec 18, 2020
Fix unit test.
PiperOrigin-RevId: 348228865
Parent: c9c230d9

Changes: 1 changed file with 71 additions and 8 deletions

official/modeling/progressive/train_lib_test.py (+71, -8)
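For convenience, one way to run the affected test locally; this invocation is an assumption (it relies on the file having the standard tf.test.main() entry point and on the Model Garden repo root being on PYTHONPATH), not something shown on this page:

    # Assumed invocation; not part of the commit.
    python -m official.modeling.progressive.train_lib_test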
@@ -17,6 +17,8 @@ import os
 from absl import flags
 from absl.testing import parameterized
+import dataclasses
+import orbit
 import tensorflow as tf
 from tensorflow.python.distribute import combinations
@@ -27,17 +29,83 @@ from official.common import registry_imports
 # pylint: enable=unused-import
 from official.core import config_definitions as cfg
+from official.core import task_factory
+from official.modeling import optimization
 from official.modeling.hyperparams import params_dict
+from official.modeling.progressive import policies
 from official.modeling.progressive import train_lib
 from official.modeling.progressive import trainer as prog_trainer_lib
-from official.nlp.data import pretrain_dataloader
-from official.nlp.tasks import progressive_masked_lm
+from official.utils.testing import mock_task

 FLAGS = flags.FLAGS

 tfm_flags.define_flags()
+
+
+@dataclasses.dataclass
+class ProgTaskConfig(cfg.TaskConfig):
+  pass
+
+
+@task_factory.register_task_cls(ProgTaskConfig)
+class ProgMockTask(policies.ProgressivePolicy, mock_task.MockTask):
+  """Progressive task for testing."""
+
+  def __init__(self, params: cfg.TaskConfig, logging_dir: str = None):
+    mock_task.MockTask.__init__(
+        self, params=params, logging_dir=logging_dir)
+    policies.ProgressivePolicy.__init__(self)
+
+  def num_stages(self):
+    return 2
+
+  def num_steps(self, stage_id):
+    return 2 if stage_id == 0 else 4
+
+  def get_model(self, stage_id, old_model=None):
+    del stage_id, old_model
+    return self.build_model()
+
+  def get_optimizer(self, stage_id):
+    """Build optimizer for each stage."""
+    params = optimization.OptimizationConfig({
+        'optimizer': {
+            'type': 'adamw',
+        },
+        'learning_rate': {
+            'type': 'polynomial',
+            'polynomial': {
+                'initial_learning_rate': 0.01,
+                'end_learning_rate': 0.0,
+                'power': 1.0,
+                'decay_steps': 10,
+            },
+        },
+        'warmup': {
+            'polynomial': {
+                'power': 1,
+                'warmup_steps': 2,
+            },
+            'type': 'polynomial',
+        }
+    })
+    opt_factory = optimization.OptimizerFactory(params)
+    optimizer = opt_factory.build_optimizer(
+        opt_factory.build_learning_rate())
+    return optimizer
+
+  def get_train_dataset(self, stage_id):
+    del stage_id
+    strategy = tf.distribute.get_strategy()
+    return orbit.utils.make_distributed_dataset(
+        strategy, self.build_inputs, None)
+
+  def get_eval_dataset(self, stage_id):
+    del stage_id
+    strategy = tf.distribute.get_strategy()
+    return orbit.utils.make_distributed_dataset(
+        strategy, self.build_inputs, None)


 class TrainTest(tf.test.TestCase, parameterized.TestCase):

   def setUp(self):
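For orientation, a minimal sketch of what the registration above enables, reusing the imports from the diff; task_factory.get_task, the logging path, and the assertions are illustrative assumptions, not lines from this commit:

    # Sketch only: construct the mock task through the factory registry.
    # Assumes task_factory.get_task(config, **kwargs) instantiates the class
    # registered for ProgTaskConfig, forwarding kwargs to __init__.
    config = ProgTaskConfig()
    task = task_factory.get_task(config, logging_dir='/tmp/prog_test')  # path is illustrative
    assert task.num_stages() == 2
    assert task.num_steps(0) == 2  # stage 0 trains for 2 steps
    assert task.num_steps(1) == 4  # later stages train for 4 steps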
@@ -76,12 +144,7 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
     model_dir = self.get_temp_dir()
     experiment_config = cfg.ExperimentConfig(
         trainer=prog_trainer_lib.ProgressiveTrainerConfig(),
-        task=progressive_masked_lm.ProgMaskedLMConfig(
-            train_data=pretrain_dataloader.BertPretrainDataConfig(
-                input_path='dummy'),
-            validation_data=pretrain_dataloader.BertPretrainDataConfig(
-                is_training=False, input_path='dummy')))
+        task=ProgTaskConfig())
     experiment_config = params_dict.override_params_dict(
         experiment_config, self._test_config, is_strict=False)
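As a footnote to the final hunk, a minimal sketch of the override step, again reusing the diff's imports; the override dict is invented for illustration, since the test's _test_config is not shown on this page:

    # Sketch only: merge a plain dict of overrides into an ExperimentConfig.
    # is_strict=False tolerates override keys the config does not declare,
    # mirroring the call in the test above.
    experiment_config = cfg.ExperimentConfig(
        trainer=prog_trainer_lib.ProgressiveTrainerConfig(),
        task=ProgTaskConfig())
    overrides = {'trainer': {'train_steps': 10}}  # illustrative values
    experiment_config = params_dict.override_params_dict(
        experiment_config, overrides, is_strict=False)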