"...csrc/ops/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "0fd0f503a14ccf3f51d291c2f89721a7bc36c7b8"
Commit 166f887c authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Make the base task as metaclass and decorate methods not implemented.

PiperOrigin-RevId: 316712226
parent 802488f1
...@@ -14,15 +14,18 @@ ...@@ -14,15 +14,18 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Defines the base task abstraction.""" """Defines the base task abstraction."""
import abc
import functools import functools
from typing import Any, Callable, Optional from typing import Any, Callable, Optional
import six
import tensorflow as tf import tensorflow as tf
from official.modeling.hyperparams import config_definitions as cfg from official.modeling.hyperparams import config_definitions as cfg
from official.utils import registry from official.utils import registry
@six.add_metaclass(abc.ABCMeta)
class Task(tf.Module): class Task(tf.Module):
"""A single-replica view of training procedure. """A single-replica view of training procedure.
...@@ -54,14 +57,13 @@ class Task(tf.Module): ...@@ -54,14 +57,13 @@ class Task(tf.Module):
""" """
pass pass
@abc.abstractmethod
def build_model(self) -> tf.keras.Model: def build_model(self) -> tf.keras.Model:
"""Creates the model architecture. """Creates the model architecture.
Returns: Returns:
A model instance. A model instance.
""" """
# TODO(hongkuny): the base task should call network factory.
pass
def compile_model(self, def compile_model(self,
model: tf.keras.Model, model: tf.keras.Model,
...@@ -98,6 +100,7 @@ class Task(tf.Module): ...@@ -98,6 +100,7 @@ class Task(tf.Module):
model.test_step = functools.partial(validation_step, model=model) model.test_step = functools.partial(validation_step, model=model)
return model return model
@abc.abstractmethod
def build_inputs(self, def build_inputs(self,
params: cfg.DataConfig, params: cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None): input_context: Optional[tf.distribute.InputContext] = None):
...@@ -112,7 +115,6 @@ class Task(tf.Module): ...@@ -112,7 +115,6 @@ class Task(tf.Module):
Returns: Returns:
A nested structure of per-replica input functions. A nested structure of per-replica input functions.
""" """
pass
def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor: def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
"""Standard interface to compute losses. """Standard interface to compute losses.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment