"vscode:/vscode.git/clone" did not exist on "b2a0170ae34e202dbe40d7e58b3dc47782bb9c59"
Commit f041db19 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 336960641
parent a027c8a6
......@@ -33,15 +33,17 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
# Special keys in train/validate step returned logs.
loss = "loss"
def __init__(self, params, logging_dir: str = None):
def __init__(self, params, logging_dir: str = None, name: str = None):
"""Task initialization.
Args:
params: the task configuration instance, which can be any of
dataclass, ConfigDict, namedtuple, etc.
params: the task configuration instance, which can be any of dataclass,
ConfigDict, namedtuple, etc.
logging_dir: a string pointing to where the model, summaries etc. will be
saved. You can also write additional stuff in this directory.
name: the task name.
"""
super().__init__(name=name)
self._task_config = params
self._logging_dir = logging_dir
......
......@@ -46,8 +46,8 @@ class MockTaskConfig(cfg.TaskConfig):
class MockTask(base_task.Task):
"""Mock task object for testing."""
def __init__(self, params=None, logging_dir=None):
super().__init__(params=params, logging_dir=logging_dir)
def __init__(self, params=None, logging_dir=None, name=None):
super().__init__(params=params, logging_dir=logging_dir, name=name)
def build_model(self, *arg, **kwargs):
inputs = tf.keras.layers.Input(shape=(2,), name="random", dtype=tf.float32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment