Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
dae3ba89
Commit
dae3ba89
authored
Dec 03, 2021
by
Yeqing Li
Committed by
A. Unique TensorFlower
Dec 03, 2021
Browse files
Internal change
PiperOrigin-RevId: 413972002
parent
7acb972a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
123 additions
and
41 deletions
+123
-41
official/core/train_utils.py
official/core/train_utils.py
+78
-40
official/core/train_utils_test.py
official/core/train_utils_test.py
+45
-1
No files found.
official/core/train_utils.py
View file @
dae3ba89
...
@@ -14,13 +14,13 @@
...
@@ -14,13 +14,13 @@
"""Training utils."""
"""Training utils."""
import
copy
import
copy
import
dataclasses
import
json
import
json
import
os
import
os
import
pprint
import
pprint
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Union
from
absl
import
logging
from
absl
import
logging
import
dataclasses
import
gin
import
gin
import
orbit
import
orbit
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -244,49 +244,87 @@ class ParseConfigOptions:
...
@@ -244,49 +244,87 @@ class ParseConfigOptions:
return
name
in
dataclasses
.
asdict
(
self
)
return
name
in
dataclasses
.
asdict
(
self
)
def
parse_configuration
(
flags_obj
,
lock_return
=
True
,
print_return
=
True
)
:
class
ExperimentParser
:
"""
Parses
Experiment
C
onfig from
f
lags
."""
"""
Constructs the
Experiment
c
onfig from
F
lags
or equivalent object.
if
flags_obj
.
experiment
is
None
:
Most of the cases, users only need to call the `parse()` function:
raise
ValueError
(
'The flag --experiment must be specified.'
)
```
builder = ExperimentParser(FLAGS)
params = builder.parse()
```
# 1. Get the default config from the registered experiment.
The advanced users can modify the flow by calling the parse_*() functions
params
=
exp_factory
.
get_exp_config
(
flags_obj
.
experiment
)
separately.
"""
def
__init__
(
self
,
flags_obj
):
self
.
_flags_obj
=
flags_obj
def
parse
(
self
):
"""Overrall process of constructing Experiment config."""
params
=
self
.
base_experiment
()
params
=
self
.
parse_config_file
(
params
)
params
=
self
.
parse_runtime
(
params
)
params
=
self
.
parse_data_service
(
params
)
params
=
self
.
parse_params_override
(
params
)
return
params
def
base_experiment
(
self
):
"""Get the base experiment config from --experiment field."""
if
self
.
_flags_obj
.
experiment
is
None
:
raise
ValueError
(
'The flag --experiment must be specified.'
)
return
exp_factory
.
get_exp_config
(
self
.
_flags_obj
.
experiment
)
# 2. Get the first level of override from `--config_file`.
def
parse_config_file
(
self
,
params
):
# `--config_file` is typically used as a template that specifies the common
"""Override the configs of params from the config_file."""
# override for a particular experiment.
for
config_file
in
self
.
_flags_obj
.
config_file
or
[]:
for
config_file
in
flags_obj
.
config_file
or
[]:
params
=
hyperparams
.
override_params_dict
(
params
=
hyperparams
.
override_params_dict
(
params
,
config_file
,
is_strict
=
True
)
params
,
config_file
,
is_strict
=
True
)
return
params
# 3. Override the TPU address and tf.data service address.
def
parse_runtime
(
self
,
params
):
"""Override the runtime configs of params from flags."""
# Override the TPU address and tf.data service address.
params
.
override
({
params
.
override
({
'runtime'
:
{
'runtime'
:
{
'tpu'
:
flags_obj
.
tpu
,
'tpu'
:
self
.
_
flags_obj
.
tpu
,
},
},
})
})
if
(
'tf_data_service'
in
flags_obj
and
flags_obj
.
tf_data_service
and
return
params
def
parse_data_service
(
self
,
params
):
"""Override the data service configs of params from flags."""
if
(
'tf_data_service'
in
self
.
_flags_obj
and
self
.
_flags_obj
.
tf_data_service
and
isinstance
(
params
.
task
,
config_definitions
.
TaskConfig
)):
isinstance
(
params
.
task
,
config_definitions
.
TaskConfig
)):
params
.
override
({
params
.
override
({
'task'
:
{
'task'
:
{
'train_data'
:
{
'train_data'
:
{
'tf_data_service_address'
:
flags_obj
.
tf_data_service
,
'tf_data_service_address'
:
self
.
_
flags_obj
.
tf_data_service
,
},
},
'validation_data'
:
{
'validation_data'
:
{
'tf_data_service_address'
:
flags_obj
.
tf_data_service
,
'tf_data_service_address'
:
self
.
_
flags_obj
.
tf_data_service
,
}
}
}
}
})
})
return
params
# 4. Get the second level of override from `--params_override`.
def
parse_params_override
(
self
,
params
):
# Get the second level of override from `--params_override`.
# `--params_override` is typically used as a further override over the
# `--params_override` is typically used as a further override over the
# template. For example, one may define a particular template for training
# template. For example, one may define a particular template for training
# ResNet50 on ImageNet in a config file and pass it via `--config_file`,
# ResNet50 on ImageNet in a config file and pass it via `--config_file`,
# then define different learning rates and pass it via `--params_override`.
# then define different learning rates and pass it via `--params_override`.
if
flags_obj
.
params_override
:
if
self
.
_
flags_obj
.
params_override
:
params
=
hyperparams
.
override_params_dict
(
params
=
hyperparams
.
override_params_dict
(
params
,
flags_obj
.
params_override
,
is_strict
=
True
)
params
,
self
.
_flags_obj
.
params_override
,
is_strict
=
True
)
return
params
def
parse_configuration
(
flags_obj
,
lock_return
=
True
,
print_return
=
True
):
"""Parses ExperimentConfig from flags."""
params
=
ExperimentParser
(
flags_obj
).
parse
()
params
.
validate
()
params
.
validate
()
if
lock_return
:
if
lock_return
:
...
...
official/core/train_utils_test.py
View file @
dae3ba89
...
@@ -13,14 +13,37 @@
...
@@ -13,14 +13,37 @@
# limitations under the License.
# limitations under the License.
"""Tests for official.core.train_utils."""
"""Tests for official.core.train_utils."""
import
os
import
os
import
pprint
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.core
import
exp_factory
from
official.core
import
test_utils
from
official.core
import
test_utils
from
official.core
import
train_utils
from
official.core
import
train_utils
from
official.modeling
import
hyperparams
@
exp_factory
.
register_config_factory
(
'foo'
)
def
foo
():
"""Multitask experiment for test."""
experiment_config
=
hyperparams
.
Config
(
default_params
=
{
'runtime'
:
{
'tpu'
:
'fake'
,
},
'task'
:
{
'model'
:
{
'model_id'
:
'bar'
,
},
},
'trainer'
:
{
'train_steps'
:
-
1
,
'validation_steps'
:
-
1
,
},
})
return
experiment_config
class
TrainUtilsTest
(
tf
.
test
.
TestCase
):
class
TrainUtilsTest
(
tf
.
test
.
TestCase
):
...
@@ -93,6 +116,27 @@ class TrainUtilsTest(tf.test.TestCase):
...
@@ -93,6 +116,27 @@ class TrainUtilsTest(tf.test.TestCase):
]
]
self
.
assertEqual
(
actual
,
expected
)
self
.
assertEqual
(
actual
,
expected
)
def
test_construct_experiment_from_flags
(
self
):
options
=
train_utils
.
ParseConfigOptions
(
experiment
=
'foo'
,
config_file
=
[],
tpu
=
'bar'
,
tf_data_service
=
''
,
params_override
=
'task.model.model_id=new,'
'trainer.train_steps=10,'
'trainer.validation_steps=11'
)
builder
=
train_utils
.
ExperimentParser
(
options
)
params_from_obj
=
builder
.
parse
()
params_from_func
=
train_utils
.
parse_configuration
(
options
)
pp
=
pprint
.
PrettyPrinter
()
self
.
assertEqual
(
pp
.
pformat
(
params_from_obj
.
as_dict
()),
pp
.
pformat
(
params_from_func
.
as_dict
()))
self
.
assertEqual
(
params_from_obj
.
runtime
.
tpu
,
'bar'
)
self
.
assertEqual
(
params_from_obj
.
task
.
model
.
model_id
,
'new'
)
self
.
assertEqual
(
params_from_obj
.
trainer
.
train_steps
,
10
)
self
.
assertEqual
(
params_from_obj
.
trainer
.
validation_steps
,
11
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment