Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
61f2bad4
Commit
61f2bad4
authored
Sep 24, 2021
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Sep 24, 2021
Browse files
Internal change
PiperOrigin-RevId: 398790392
parent
c72ec9d3
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
74 additions
and
7 deletions
+74
-7
official/core/task_factory.py
official/core/task_factory.py
+3
-0
official/modeling/hyperparams/base_config.py
official/modeling/hyperparams/base_config.py
+43
-5
official/modeling/hyperparams/base_config_test.py
official/modeling/hyperparams/base_config_test.py
+25
-0
official/utils/testing/mock_task.py
official/utils/testing/mock_task.py
+3
-2
No files found.
official/core/task_factory.py
View file @
61f2bad4
...
@@ -57,6 +57,9 @@ def register_task_cls(task_config_cls):
...
@@ -57,6 +57,9 @@ def register_task_cls(task_config_cls):
def
get_task
(
task_config
,
**
kwargs
):
def
get_task
(
task_config
,
**
kwargs
):
"""Creates a Task (of suitable subclass type) from task_config."""
"""Creates a Task (of suitable subclass type) from task_config."""
# TODO(hongkuny): deprecate the task factory to use config.BUILDER.
if
task_config
.
BUILDER
is
not
None
:
return
task_config
.
BUILDER
(
task_config
,
**
kwargs
)
return
get_task_cls
(
task_config
.
__class__
)(
task_config
,
**
kwargs
)
return
get_task_cls
(
task_config
.
__class__
)(
task_config
,
**
kwargs
)
...
...
official/modeling/hyperparams/base_config.py
View file @
61f2bad4
...
@@ -13,18 +13,46 @@
...
@@ -13,18 +13,46 @@
# limitations under the License.
# limitations under the License.
"""Base configurations to standardize experiments."""
"""Base configurations to standardize experiments."""
import
copy
import
copy
import
dataclasses
import
functools
import
functools
import
inspect
from
typing
import
Any
,
List
,
Mapping
,
Optional
,
Type
from
typing
import
Any
,
List
,
Mapping
,
Optional
,
Type
from
absl
import
logging
import
dataclasses
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
import
yaml
import
yaml
from
official.modeling.hyperparams
import
params_dict
from
official.modeling.hyperparams
import
params_dict
_BOUND
=
set
()
def
bind
(
config_cls
):
"""Bind a class to config cls."""
if
not
inspect
.
isclass
(
config_cls
):
raise
ValueError
(
'The bind decorator is supposed to apply on the class '
f
'attribute. Received
{
config_cls
}
, not a class.'
)
def
decorator
(
builder
):
if
config_cls
in
_BOUND
:
raise
ValueError
(
'Inside a program, we should not bind the config with a'
' class twice.'
)
if
inspect
.
isclass
(
builder
):
config_cls
.
_BUILDER
=
builder
# pylint: disable=protected-access
elif
inspect
.
isfunction
(
builder
):
def
_wrapper
(
self
,
*
args
,
**
kwargs
):
# pylint: disable=unused-argument
return
builder
(
*
args
,
**
kwargs
)
config_cls
.
_BUILDER
=
_wrapper
# pylint: disable=protected-access
else
:
raise
ValueError
(
f
'The `BUILDER` type is not supported:
{
builder
}
'
)
_BOUND
.
add
(
config_cls
)
return
builder
return
decorator
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
Config
(
params_dict
.
ParamsDict
):
class
Config
(
params_dict
.
ParamsDict
):
...
@@ -40,7 +68,8 @@ class Config(params_dict.ParamsDict):
...
@@ -40,7 +68,8 @@ class Config(params_dict.ParamsDict):
If you define/annotate some field as Dict, the field will convert to a
If you define/annotate some field as Dict, the field will convert to a
`Config` instance and lose the dictionary type.
`Config` instance and lose the dictionary type.
"""
"""
# The class or method to bind with the params class.
_BUILDER
=
None
# It's safe to add bytes and other immutable types here.
# It's safe to add bytes and other immutable types here.
IMMUTABLE_TYPES
=
(
str
,
int
,
float
,
bool
,
type
(
None
))
IMMUTABLE_TYPES
=
(
str
,
int
,
float
,
bool
,
type
(
None
))
# It's safe to add set, frozenset and other collections here.
# It's safe to add set, frozenset and other collections here.
...
@@ -54,6 +83,10 @@ class Config(params_dict.ParamsDict):
...
@@ -54,6 +83,10 @@ class Config(params_dict.ParamsDict):
default_params
=
default_params
,
default_params
=
default_params
,
restrictions
=
restrictions
)
restrictions
=
restrictions
)
@
property
def
BUILDER
(
self
):
return
self
.
_BUILDER
@
classmethod
@
classmethod
def
_isvalidsequence
(
cls
,
v
):
def
_isvalidsequence
(
cls
,
v
):
"""Check if the input values are valid sequences.
"""Check if the input values are valid sequences.
...
@@ -188,6 +221,11 @@ class Config(params_dict.ParamsDict):
...
@@ -188,6 +221,11 @@ class Config(params_dict.ParamsDict):
self
.
__dict__
[
k
]
=
self
.
_import_config
(
v
,
subconfig_type
)
self
.
__dict__
[
k
]
=
self
.
_import_config
(
v
,
subconfig_type
)
def
__setattr__
(
self
,
k
,
v
):
def
__setattr__
(
self
,
k
,
v
):
if
k
==
'BUILDER'
or
k
==
'_BUILDER'
:
raise
AttributeError
(
'`BUILDER` is a property and `_BUILDER` is the '
'reserved class attribute. We should only assign '
'`_BUILDER` at the class level.'
)
if
k
not
in
self
.
RESERVED_ATTR
:
if
k
not
in
self
.
RESERVED_ATTR
:
if
getattr
(
self
,
'_locked'
,
False
):
if
getattr
(
self
,
'_locked'
,
False
):
raise
ValueError
(
'The Config has been locked. '
'No change is allowed.'
)
raise
ValueError
(
'The Config has been locked. '
'No change is allowed.'
)
...
@@ -265,4 +303,4 @@ class Config(params_dict.ParamsDict):
...
@@ -265,4 +303,4 @@ class Config(params_dict.ParamsDict):
attributes
=
list
(
cls
.
__annotations__
.
keys
())
attributes
=
list
(
cls
.
__annotations__
.
keys
())
default_params
=
{
a
:
p
for
a
,
p
in
zip
(
attributes
,
args
)}
default_params
=
{
a
:
p
for
a
,
p
in
zip
(
attributes
,
args
)}
default_params
.
update
(
kwargs
)
default_params
.
update
(
kwargs
)
return
cls
(
default_params
)
return
cls
(
default_params
=
default_params
)
official/modeling/hyperparams/base_config_test.py
View file @
61f2bad4
...
@@ -91,6 +91,31 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -91,6 +91,31 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
with
self
.
assertRaises
(
AttributeError
):
with
self
.
assertRaises
(
AttributeError
):
_
=
params
.
a
_
=
params
.
a
def
test_cls
(
self
):
params
=
base_config
.
Config
()
with
self
.
assertRaisesRegex
(
AttributeError
,
'`BUILDER` is a property and `_BUILDER` is the reserved'
):
params
.
BUILDER
=
DumpConfig2
with
self
.
assertRaisesRegex
(
AttributeError
,
'`BUILDER` is a property and `_BUILDER` is the reserved'
):
params
.
_BUILDER
=
DumpConfig2
base_config
.
bind
(
DumpConfig1
)(
DumpConfig2
)
params
=
DumpConfig1
()
self
.
assertEqual
(
params
.
BUILDER
,
DumpConfig2
)
with
self
.
assertRaisesRegex
(
ValueError
,
'Inside a program, we should not bind'
):
base_config
.
bind
(
DumpConfig1
)(
DumpConfig2
)
def
_test
():
return
'test'
base_config
.
bind
(
DumpConfig2
)(
_test
)
params
=
DumpConfig2
()
self
.
assertEqual
(
params
.
BUILDER
(),
'test'
)
def
test_nested_config_types
(
self
):
def
test_nested_config_types
(
self
):
config
=
DumpConfig3
()
config
=
DumpConfig3
()
self
.
assertIsInstance
(
config
.
e
,
DumpConfig1
)
self
.
assertIsInstance
(
config
.
e
,
DumpConfig1
)
...
...
official/utils/testing/mock_task.py
View file @
61f2bad4
...
@@ -15,13 +15,14 @@
...
@@ -15,13 +15,14 @@
"""Mock task for testing."""
"""Mock task for testing."""
import
dataclasses
import
dataclasses
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.core
import
base_task
from
official.core
import
base_task
from
official.core
import
config_definitions
as
cfg
from
official.core
import
config_definitions
as
cfg
from
official.core
import
exp_factory
from
official.core
import
exp_factory
from
official.
core
import
t
as
k_factory
from
official.
modeling.hyperparams
import
b
as
e_config
class
MockModel
(
tf
.
keras
.
Model
):
class
MockModel
(
tf
.
keras
.
Model
):
...
@@ -41,7 +42,7 @@ class MockTaskConfig(cfg.TaskConfig):
...
@@ -41,7 +42,7 @@ class MockTaskConfig(cfg.TaskConfig):
pass
pass
@
t
as
k_factory
.
register_task_cls
(
MockTaskConfig
)
@
b
as
e_config
.
bind
(
MockTaskConfig
)
class
MockTask
(
base_task
.
Task
):
class
MockTask
(
base_task
.
Task
):
"""Mock task object for testing."""
"""Mock task object for testing."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment