Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
9a8ea6a0
Commit
9a8ea6a0
authored
Sep 24, 2021
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Sep 24, 2021
Browse files
Internal change
PiperOrigin-RevId: 398790392
parent
3ea57553
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
74 additions
and
7 deletions
+74
-7
official/core/task_factory.py
official/core/task_factory.py
+3
-0
official/modeling/hyperparams/base_config.py
official/modeling/hyperparams/base_config.py
+43
-5
official/modeling/hyperparams/base_config_test.py
official/modeling/hyperparams/base_config_test.py
+25
-0
official/utils/testing/mock_task.py
official/utils/testing/mock_task.py
+3
-2
No files found.
official/core/task_factory.py
View file @
9a8ea6a0
...
...
@@ -57,6 +57,9 @@ def register_task_cls(task_config_cls):
def
get_task
(
task_config
,
**
kwargs
):
"""Creates a Task (of suitable subclass type) from task_config."""
# TODO(hongkuny): deprecate the task factory to use config.BUILDER.
if
task_config
.
BUILDER
is
not
None
:
return
task_config
.
BUILDER
(
task_config
,
**
kwargs
)
return
get_task_cls
(
task_config
.
__class__
)(
task_config
,
**
kwargs
)
...
...
official/modeling/hyperparams/base_config.py
View file @
9a8ea6a0
...
...
@@ -13,18 +13,46 @@
# limitations under the License.
"""Base configurations to standardize experiments."""
import
copy
import
dataclasses
import
functools
import
inspect
from
typing
import
Any
,
List
,
Mapping
,
Optional
,
Type
from
absl
import
logging
import
dataclasses
from
absl
import
logging
import
tensorflow
as
tf
import
yaml
from
official.modeling.hyperparams
import
params_dict
_BOUND
=
set
()
def
bind
(
config_cls
):
"""Bind a class to config cls."""
if
not
inspect
.
isclass
(
config_cls
):
raise
ValueError
(
'The bind decorator is supposed to apply on the class '
f
'attribute. Received
{
config_cls
}
, not a class.'
)
def
decorator
(
builder
):
if
config_cls
in
_BOUND
:
raise
ValueError
(
'Inside a program, we should not bind the config with a'
' class twice.'
)
if
inspect
.
isclass
(
builder
):
config_cls
.
_BUILDER
=
builder
# pylint: disable=protected-access
elif
inspect
.
isfunction
(
builder
):
def
_wrapper
(
self
,
*
args
,
**
kwargs
):
# pylint: disable=unused-argument
return
builder
(
*
args
,
**
kwargs
)
config_cls
.
_BUILDER
=
_wrapper
# pylint: disable=protected-access
else
:
raise
ValueError
(
f
'The `BUILDER` type is not supported:
{
builder
}
'
)
_BOUND
.
add
(
config_cls
)
return
builder
return
decorator
@
dataclasses
.
dataclass
class
Config
(
params_dict
.
ParamsDict
):
...
...
@@ -40,7 +68,8 @@ class Config(params_dict.ParamsDict):
If you define/annotate some field as Dict, the field will convert to a
`Config` instance and lose the dictionary type.
"""
# The class or method to bind with the params class.
_BUILDER
=
None
# It's safe to add bytes and other immutable types here.
IMMUTABLE_TYPES
=
(
str
,
int
,
float
,
bool
,
type
(
None
))
# It's safe to add set, frozenset and other collections here.
...
...
@@ -54,6 +83,10 @@ class Config(params_dict.ParamsDict):
default_params
=
default_params
,
restrictions
=
restrictions
)
@
property
def
BUILDER
(
self
):
return
self
.
_BUILDER
@
classmethod
def
_isvalidsequence
(
cls
,
v
):
"""Check if the input values are valid sequences.
...
...
@@ -188,6 +221,11 @@ class Config(params_dict.ParamsDict):
self
.
__dict__
[
k
]
=
self
.
_import_config
(
v
,
subconfig_type
)
def
__setattr__
(
self
,
k
,
v
):
if
k
==
'BUILDER'
or
k
==
'_BUILDER'
:
raise
AttributeError
(
'`BUILDER` is a property and `_BUILDER` is the '
'reserved class attribute. We should only assign '
'`_BUILDER` at the class level.'
)
if
k
not
in
self
.
RESERVED_ATTR
:
if
getattr
(
self
,
'_locked'
,
False
):
raise
ValueError
(
'The Config has been locked. '
'No change is allowed.'
)
...
...
@@ -265,4 +303,4 @@ class Config(params_dict.ParamsDict):
attributes
=
list
(
cls
.
__annotations__
.
keys
())
default_params
=
{
a
:
p
for
a
,
p
in
zip
(
attributes
,
args
)}
default_params
.
update
(
kwargs
)
return
cls
(
default_params
)
return
cls
(
default_params
=
default_params
)
official/modeling/hyperparams/base_config_test.py
View file @
9a8ea6a0
...
...
@@ -91,6 +91,31 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
with
self
.
assertRaises
(
AttributeError
):
_
=
params
.
a
def
test_cls
(
self
):
params
=
base_config
.
Config
()
with
self
.
assertRaisesRegex
(
AttributeError
,
'`BUILDER` is a property and `_BUILDER` is the reserved'
):
params
.
BUILDER
=
DumpConfig2
with
self
.
assertRaisesRegex
(
AttributeError
,
'`BUILDER` is a property and `_BUILDER` is the reserved'
):
params
.
_BUILDER
=
DumpConfig2
base_config
.
bind
(
DumpConfig1
)(
DumpConfig2
)
params
=
DumpConfig1
()
self
.
assertEqual
(
params
.
BUILDER
,
DumpConfig2
)
with
self
.
assertRaisesRegex
(
ValueError
,
'Inside a program, we should not bind'
):
base_config
.
bind
(
DumpConfig1
)(
DumpConfig2
)
def
_test
():
return
'test'
base_config
.
bind
(
DumpConfig2
)(
_test
)
params
=
DumpConfig2
()
self
.
assertEqual
(
params
.
BUILDER
(),
'test'
)
def
test_nested_config_types
(
self
):
config
=
DumpConfig3
()
self
.
assertIsInstance
(
config
.
e
,
DumpConfig1
)
...
...
official/utils/testing/mock_task.py
View file @
9a8ea6a0
...
...
@@ -15,13 +15,14 @@
"""Mock task for testing."""
import
dataclasses
import
numpy
as
np
import
tensorflow
as
tf
from
official.core
import
base_task
from
official.core
import
config_definitions
as
cfg
from
official.core
import
exp_factory
from
official.
core
import
t
as
k_factory
from
official.
modeling.hyperparams
import
b
as
e_config
class
MockModel
(
tf
.
keras
.
Model
):
...
...
@@ -41,7 +42,7 @@ class MockTaskConfig(cfg.TaskConfig):
pass
@
t
as
k_factory
.
register_task_cls
(
MockTaskConfig
)
@
b
as
e_config
.
bind
(
MockTaskConfig
)
class
MockTask
(
base_task
.
Task
):
"""Mock task object for testing."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment