"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "7c5fce187d48eca50641e541f2c5bc3f3ef12689"
Commit 61f2bad4 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 398790392
parent c72ec9d3
...@@ -57,6 +57,9 @@ def register_task_cls(task_config_cls): ...@@ -57,6 +57,9 @@ def register_task_cls(task_config_cls):
def get_task(task_config, **kwargs): def get_task(task_config, **kwargs):
"""Creates a Task (of suitable subclass type) from task_config.""" """Creates a Task (of suitable subclass type) from task_config."""
# TODO(hongkuny): deprecate the task factory to use config.BUILDER.
if task_config.BUILDER is not None:
return task_config.BUILDER(task_config, **kwargs)
return get_task_cls(task_config.__class__)(task_config, **kwargs) return get_task_cls(task_config.__class__)(task_config, **kwargs)
......
...@@ -13,18 +13,46 @@ ...@@ -13,18 +13,46 @@
# limitations under the License. # limitations under the License.
"""Base configurations to standardize experiments.""" """Base configurations to standardize experiments."""
import copy import copy
import dataclasses
import functools import functools
import inspect
from typing import Any, List, Mapping, Optional, Type from typing import Any, List, Mapping, Optional, Type
from absl import logging
import dataclasses from absl import logging
import tensorflow as tf import tensorflow as tf
import yaml import yaml
from official.modeling.hyperparams import params_dict from official.modeling.hyperparams import params_dict
_BOUND = set()
def bind(config_cls):
"""Bind a class to config cls."""
if not inspect.isclass(config_cls):
raise ValueError('The bind decorator is supposed to apply on the class '
f'attribute. Received {config_cls}, not a class.')
def decorator(builder):
if config_cls in _BOUND:
raise ValueError('Inside a program, we should not bind the config with a'
' class twice.')
if inspect.isclass(builder):
config_cls._BUILDER = builder # pylint: disable=protected-access
elif inspect.isfunction(builder):
def _wrapper(self, *args, **kwargs): # pylint: disable=unused-argument
return builder(*args, **kwargs)
config_cls._BUILDER = _wrapper # pylint: disable=protected-access
else:
raise ValueError(f'The `BUILDER` type is not supported: {builder}')
_BOUND.add(config_cls)
return builder
return decorator
@dataclasses.dataclass @dataclasses.dataclass
class Config(params_dict.ParamsDict): class Config(params_dict.ParamsDict):
...@@ -40,7 +68,8 @@ class Config(params_dict.ParamsDict): ...@@ -40,7 +68,8 @@ class Config(params_dict.ParamsDict):
If you define/annotate some field as Dict, the field will convert to a If you define/annotate some field as Dict, the field will convert to a
`Config` instance and lose the dictionary type. `Config` instance and lose the dictionary type.
""" """
# The class or method to bind with the params class.
_BUILDER = None
# It's safe to add bytes and other immutable types here. # It's safe to add bytes and other immutable types here.
IMMUTABLE_TYPES = (str, int, float, bool, type(None)) IMMUTABLE_TYPES = (str, int, float, bool, type(None))
# It's safe to add set, frozenset and other collections here. # It's safe to add set, frozenset and other collections here.
...@@ -54,6 +83,10 @@ class Config(params_dict.ParamsDict): ...@@ -54,6 +83,10 @@ class Config(params_dict.ParamsDict):
default_params=default_params, default_params=default_params,
restrictions=restrictions) restrictions=restrictions)
@property
def BUILDER(self):
return self._BUILDER
@classmethod @classmethod
def _isvalidsequence(cls, v): def _isvalidsequence(cls, v):
"""Check if the input values are valid sequences. """Check if the input values are valid sequences.
...@@ -188,6 +221,11 @@ class Config(params_dict.ParamsDict): ...@@ -188,6 +221,11 @@ class Config(params_dict.ParamsDict):
self.__dict__[k] = self._import_config(v, subconfig_type) self.__dict__[k] = self._import_config(v, subconfig_type)
def __setattr__(self, k, v): def __setattr__(self, k, v):
if k == 'BUILDER' or k == '_BUILDER':
raise AttributeError('`BUILDER` is a property and `_BUILDER` is the '
'reserved class attribute. We should only assign '
'`_BUILDER` at the class level.')
if k not in self.RESERVED_ATTR: if k not in self.RESERVED_ATTR:
if getattr(self, '_locked', False): if getattr(self, '_locked', False):
raise ValueError('The Config has been locked. ' 'No change is allowed.') raise ValueError('The Config has been locked. ' 'No change is allowed.')
...@@ -265,4 +303,4 @@ class Config(params_dict.ParamsDict): ...@@ -265,4 +303,4 @@ class Config(params_dict.ParamsDict):
attributes = list(cls.__annotations__.keys()) attributes = list(cls.__annotations__.keys())
default_params = {a: p for a, p in zip(attributes, args)} default_params = {a: p for a, p in zip(attributes, args)}
default_params.update(kwargs) default_params.update(kwargs)
return cls(default_params) return cls(default_params=default_params)
...@@ -91,6 +91,31 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase): ...@@ -91,6 +91,31 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
with self.assertRaises(AttributeError): with self.assertRaises(AttributeError):
_ = params.a _ = params.a
def test_cls(self):
params = base_config.Config()
with self.assertRaisesRegex(
AttributeError,
'`BUILDER` is a property and `_BUILDER` is the reserved'):
params.BUILDER = DumpConfig2
with self.assertRaisesRegex(
AttributeError,
'`BUILDER` is a property and `_BUILDER` is the reserved'):
params._BUILDER = DumpConfig2
base_config.bind(DumpConfig1)(DumpConfig2)
params = DumpConfig1()
self.assertEqual(params.BUILDER, DumpConfig2)
with self.assertRaisesRegex(ValueError,
'Inside a program, we should not bind'):
base_config.bind(DumpConfig1)(DumpConfig2)
def _test():
return 'test'
base_config.bind(DumpConfig2)(_test)
params = DumpConfig2()
self.assertEqual(params.BUILDER(), 'test')
def test_nested_config_types(self): def test_nested_config_types(self):
config = DumpConfig3() config = DumpConfig3()
self.assertIsInstance(config.e, DumpConfig1) self.assertIsInstance(config.e, DumpConfig1)
......
...@@ -15,13 +15,14 @@ ...@@ -15,13 +15,14 @@
"""Mock task for testing.""" """Mock task for testing."""
import dataclasses import dataclasses
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from official.core import base_task from official.core import base_task
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
from official.core import task_factory from official.modeling.hyperparams import base_config
class MockModel(tf.keras.Model): class MockModel(tf.keras.Model):
...@@ -41,7 +42,7 @@ class MockTaskConfig(cfg.TaskConfig): ...@@ -41,7 +42,7 @@ class MockTaskConfig(cfg.TaskConfig):
pass pass
@task_factory.register_task_cls(MockTaskConfig) @base_config.bind(MockTaskConfig)
class MockTask(base_task.Task): class MockTask(base_task.Task):
"""Mock task object for testing.""" """Mock task object for testing."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment