Commit 61f2bad4 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 398790392
parent c72ec9d3
......@@ -57,6 +57,9 @@ def register_task_cls(task_config_cls):
def get_task(task_config, **kwargs):
"""Creates a Task (of suitable subclass type) from task_config."""
# TODO(hongkuny): deprecate the task factory to use config.BUILDER.
if task_config.BUILDER is not None:
return task_config.BUILDER(task_config, **kwargs)
return get_task_cls(task_config.__class__)(task_config, **kwargs)
......
......@@ -13,18 +13,46 @@
# limitations under the License.
"""Base configurations to standardize experiments."""
import copy
import dataclasses
import functools
import inspect
from typing import Any, List, Mapping, Optional, Type
from absl import logging
import dataclasses
from absl import logging
import tensorflow as tf
import yaml
from official.modeling.hyperparams import params_dict
_BOUND = set()
def bind(config_cls):
"""Bind a class to config cls."""
if not inspect.isclass(config_cls):
raise ValueError('The bind decorator is supposed to apply on the class '
f'attribute. Received {config_cls}, not a class.')
def decorator(builder):
if config_cls in _BOUND:
raise ValueError('Inside a program, we should not bind the config with a'
' class twice.')
if inspect.isclass(builder):
config_cls._BUILDER = builder # pylint: disable=protected-access
elif inspect.isfunction(builder):
def _wrapper(self, *args, **kwargs): # pylint: disable=unused-argument
return builder(*args, **kwargs)
config_cls._BUILDER = _wrapper # pylint: disable=protected-access
else:
raise ValueError(f'The `BUILDER` type is not supported: {builder}')
_BOUND.add(config_cls)
return builder
return decorator
@dataclasses.dataclass
class Config(params_dict.ParamsDict):
......@@ -40,7 +68,8 @@ class Config(params_dict.ParamsDict):
If you define/annotate some field as Dict, the field will convert to a
`Config` instance and lose the dictionary type.
"""
# The class or method to bind with the params class.
_BUILDER = None
# It's safe to add bytes and other immutable types here.
IMMUTABLE_TYPES = (str, int, float, bool, type(None))
# It's safe to add set, frozenset and other collections here.
......@@ -54,6 +83,10 @@ class Config(params_dict.ParamsDict):
default_params=default_params,
restrictions=restrictions)
@property
def BUILDER(self):
return self._BUILDER
@classmethod
def _isvalidsequence(cls, v):
"""Check if the input values are valid sequences.
......@@ -188,6 +221,11 @@ class Config(params_dict.ParamsDict):
self.__dict__[k] = self._import_config(v, subconfig_type)
def __setattr__(self, k, v):
if k == 'BUILDER' or k == '_BUILDER':
raise AttributeError('`BUILDER` is a property and `_BUILDER` is the '
'reserved class attribute. We should only assign '
'`_BUILDER` at the class level.')
if k not in self.RESERVED_ATTR:
if getattr(self, '_locked', False):
raise ValueError('The Config has been locked. ' 'No change is allowed.')
......@@ -265,4 +303,4 @@ class Config(params_dict.ParamsDict):
attributes = list(cls.__annotations__.keys())
default_params = {a: p for a, p in zip(attributes, args)}
default_params.update(kwargs)
return cls(default_params)
return cls(default_params=default_params)
......@@ -91,6 +91,31 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
with self.assertRaises(AttributeError):
_ = params.a
def test_cls(self):
params = base_config.Config()
with self.assertRaisesRegex(
AttributeError,
'`BUILDER` is a property and `_BUILDER` is the reserved'):
params.BUILDER = DumpConfig2
with self.assertRaisesRegex(
AttributeError,
'`BUILDER` is a property and `_BUILDER` is the reserved'):
params._BUILDER = DumpConfig2
base_config.bind(DumpConfig1)(DumpConfig2)
params = DumpConfig1()
self.assertEqual(params.BUILDER, DumpConfig2)
with self.assertRaisesRegex(ValueError,
'Inside a program, we should not bind'):
base_config.bind(DumpConfig1)(DumpConfig2)
def _test():
return 'test'
base_config.bind(DumpConfig2)(_test)
params = DumpConfig2()
self.assertEqual(params.BUILDER(), 'test')
def test_nested_config_types(self):
config = DumpConfig3()
self.assertIsInstance(config.e, DumpConfig1)
......
......@@ -15,13 +15,14 @@
"""Mock task for testing."""
import dataclasses
import numpy as np
import tensorflow as tf
from official.core import base_task
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.core import task_factory
from official.modeling.hyperparams import base_config
class MockModel(tf.keras.Model):
......@@ -41,7 +42,7 @@ class MockTaskConfig(cfg.TaskConfig):
pass
@task_factory.register_task_cls(MockTaskConfig)
@base_config.bind(MockTaskConfig)
class MockTask(base_task.Task):
"""Mock task object for testing."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment