Commit 166f887c authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Make the base task as metaclass and decorate methods not implemented.

PiperOrigin-RevId: 316712226
parent 802488f1
......@@ -14,15 +14,18 @@
# limitations under the License.
# ==============================================================================
"""Defines the base task abstraction."""
import abc
import functools
from typing import Any, Callable, Optional
import six
import tensorflow as tf
from official.modeling.hyperparams import config_definitions as cfg
from official.utils import registry
@six.add_metaclass(abc.ABCMeta)
class Task(tf.Module):
"""A single-replica view of training procedure.
......@@ -54,14 +57,13 @@ class Task(tf.Module):
"""
pass
@abc.abstractmethod
def build_model(self) -> tf.keras.Model:
"""Creates the model architecture.
Returns:
A model instance.
"""
# TODO(hongkuny): the base task should call network factory.
pass
def compile_model(self,
model: tf.keras.Model,
......@@ -98,6 +100,7 @@ class Task(tf.Module):
model.test_step = functools.partial(validation_step, model=model)
return model
@abc.abstractmethod
def build_inputs(self,
params: cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None):
......@@ -112,7 +115,6 @@ class Task(tf.Module):
Returns:
A nested structure of per-replica input functions.
"""
pass
def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
"""Standard interface to compute losses.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment