Commit 5eb294f8 authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 324140487
parent a62c2bfc
...@@ -23,7 +23,6 @@ import six ...@@ -23,7 +23,6 @@ import six
import tensorflow as tf import tensorflow as tf
from official.modeling.hyperparams import config_definitions as cfg from official.modeling.hyperparams import config_definitions as cfg
from official.utils import registry
@six.add_metaclass(abc.ABCMeta) @six.add_metaclass(abc.ABCMeta)
...@@ -295,49 +294,3 @@ class Task(tf.Module): ...@@ -295,49 +294,3 @@ class Task(tf.Module):
"""Optional reduce of aggregated logs over validation steps.""" """Optional reduce of aggregated logs over validation steps."""
return {} return {}
_REGISTERED_TASK_CLS = {}
# TODO(b/158268740): Move these outside the base class file.
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def register_task_cls(task_config_cls):
"""Decorates a factory of Tasks for lookup by a subclass of TaskConfig.
This decorator supports registration of tasks as follows:
```
@dataclasses.dataclass
class MyTaskConfig(TaskConfig):
# Add fields here.
pass
@register_task_cls(MyTaskConfig)
class MyTask(Task):
# Inherits def __init__(self, task_config).
pass
my_task_config = MyTaskConfig()
my_task = get_task(my_task_config) # Returns MyTask(my_task_config).
```
Besisdes a class itself, other callables that create a Task from a TaskConfig
can be decorated by the result of this function, as long as there is at most
one registration for each config class.
Args:
task_config_cls: a subclass of TaskConfig (*not* an instance of TaskConfig).
Each task_config_cls can only be used for a single registration.
Returns:
A callable for use as class decorator that registers the decorated class
for creation from an instance of task_config_cls.
"""
return registry.register(_REGISTERED_TASK_CLS, task_config_cls)
# The user-visible get_task() is defined after classes have been registered.
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def get_task_cls(task_config_cls):
task_cls = registry.lookup(_REGISTERED_TASK_CLS, task_config_cls)
return task_cls
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experiment factory methods."""
from official.modeling.hyperparams import config_definitions as cfg
from official.utils import registry
_REGISTERED_CONFIGS = {}
def register_config_factory(name):
"""Register ExperimentConfig factory method."""
return registry.register(_REGISTERED_CONFIGS, name)
def get_exp_config_creater(exp_name: str):
"""Looks up ExperimentConfig factory methods."""
exp_creater = registry.lookup(_REGISTERED_CONFIGS, exp_name)
return exp_creater
def get_exp_config(exp_name: str) -> cfg.ExperimentConfig:
return get_exp_config_creater(exp_name)()
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A global factory to register and access all registered tasks."""
from official.utils import registry
_REGISTERED_TASK_CLS = {}
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def register_task_cls(task_config_cls):
"""Decorates a factory of Tasks for lookup by a subclass of TaskConfig.
This decorator supports registration of tasks as follows:
```
@dataclasses.dataclass
class MyTaskConfig(TaskConfig):
# Add fields here.
pass
@register_task_cls(MyTaskConfig)
class MyTask(Task):
# Inherits def __init__(self, task_config).
pass
my_task_config = MyTaskConfig()
my_task = get_task(my_task_config) # Returns MyTask(my_task_config).
```
Besisdes a class itself, other callables that create a Task from a TaskConfig
can be decorated by the result of this function, as long as there is at most
one registration for each config class.
Args:
task_config_cls: a subclass of TaskConfig (*not* an instance of TaskConfig).
Each task_config_cls can only be used for a single registration.
Returns:
A callable for use as class decorator that registers the decorated class
for creation from an instance of task_config_cls.
"""
return registry.register(_REGISTERED_TASK_CLS, task_config_cls)
def get_task(task_config, **kwargs):
"""Creates a Task (of suitable subclass type) from task_config."""
return get_task_cls(task_config.__class__)(task_config, **kwargs)
# The user-visible get_task() is defined after classes have been registered.
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def get_task_cls(task_config_cls):
task_cls = registry.lookup(_REGISTERED_TASK_CLS, task_config_cls)
return task_cls
...@@ -21,7 +21,6 @@ import dataclasses ...@@ -21,7 +21,6 @@ import dataclasses
from official.modeling.hyperparams import base_config from official.modeling.hyperparams import base_config
from official.modeling.optimization.configs import optimization_config from official.modeling.optimization.configs import optimization_config
from official.utils import registry
OptimizationConfig = optimization_config.OptimizationConfig OptimizationConfig = optimization_config.OptimizationConfig
...@@ -219,16 +218,3 @@ class ExperimentConfig(base_config.Config): ...@@ -219,16 +218,3 @@ class ExperimentConfig(base_config.Config):
trainer: TrainerConfig = TrainerConfig() trainer: TrainerConfig = TrainerConfig()
runtime: RuntimeConfig = RuntimeConfig() runtime: RuntimeConfig = RuntimeConfig()
_REGISTERED_CONFIGS = {}
def register_config_factory(name):
"""Register ExperimentConfig factory method."""
return registry.register(_REGISTERED_CONFIGS, name)
def get_exp_config_creater(exp_name: str):
"""Looks up ExperimentConfig factory methods."""
exp_creater = registry.lookup(_REGISTERED_CONFIGS, exp_name)
return exp_creater
...@@ -18,6 +18,7 @@ import dataclasses ...@@ -18,6 +18,7 @@ import dataclasses
import tensorflow as tf import tensorflow as tf
from official.core import base_task from official.core import base_task
from official.core import task_factory
from official.modeling.hyperparams import config_definitions as cfg from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import bert from official.nlp.configs import bert
from official.nlp.configs import electra from official.nlp.configs import electra
...@@ -39,7 +40,7 @@ class ELECTRAPretrainConfig(cfg.TaskConfig): ...@@ -39,7 +40,7 @@ class ELECTRAPretrainConfig(cfg.TaskConfig):
validation_data: cfg.DataConfig = cfg.DataConfig() validation_data: cfg.DataConfig = cfg.DataConfig()
@base_task.register_task_cls(ELECTRAPretrainConfig) @task_factory.register_task_cls(ELECTRAPretrainConfig)
class ELECTRAPretrainTask(base_task.Task): class ELECTRAPretrainTask(base_task.Task):
"""ELECTRA Pretrain Task (Masked LM + Replaced Token Detection).""" """ELECTRA Pretrain Task (Masked LM + Replaced Token Detection)."""
......
...@@ -18,6 +18,7 @@ import dataclasses ...@@ -18,6 +18,7 @@ import dataclasses
import tensorflow as tf import tensorflow as tf
from official.core import base_task from official.core import base_task
from official.core import task_factory
from official.modeling.hyperparams import config_definitions as cfg from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import bert from official.nlp.configs import bert
from official.nlp.data import data_loader_factory from official.nlp.data import data_loader_factory
...@@ -34,7 +35,7 @@ class MaskedLMConfig(cfg.TaskConfig): ...@@ -34,7 +35,7 @@ class MaskedLMConfig(cfg.TaskConfig):
validation_data: cfg.DataConfig = cfg.DataConfig() validation_data: cfg.DataConfig = cfg.DataConfig()
@base_task.register_task_cls(MaskedLMConfig) @task_factory.register_task_cls(MaskedLMConfig)
class MaskedLMTask(base_task.Task): class MaskedLMTask(base_task.Task):
"""Mock task object for testing.""" """Mock task object for testing."""
......
...@@ -23,6 +23,7 @@ import tensorflow as tf ...@@ -23,6 +23,7 @@ import tensorflow as tf
import tensorflow_hub as hub import tensorflow_hub as hub
from official.core import base_task from official.core import base_task
from official.core import task_factory
from official.modeling.hyperparams import base_config from official.modeling.hyperparams import base_config
from official.modeling.hyperparams import config_definitions as cfg from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.bert import squad_evaluate_v1_1 from official.nlp.bert import squad_evaluate_v1_1
...@@ -57,7 +58,7 @@ class QuestionAnsweringConfig(cfg.TaskConfig): ...@@ -57,7 +58,7 @@ class QuestionAnsweringConfig(cfg.TaskConfig):
validation_data: cfg.DataConfig = cfg.DataConfig() validation_data: cfg.DataConfig = cfg.DataConfig()
@base_task.register_task_cls(QuestionAnsweringConfig) @task_factory.register_task_cls(QuestionAnsweringConfig)
class QuestionAnsweringTask(base_task.Task): class QuestionAnsweringTask(base_task.Task):
"""Task object for question answering.""" """Task object for question answering."""
......
...@@ -26,6 +26,7 @@ import tensorflow as tf ...@@ -26,6 +26,7 @@ import tensorflow as tf
import tensorflow_hub as hub import tensorflow_hub as hub
from official.core import base_task from official.core import base_task
from official.core import task_factory
from official.modeling.hyperparams import base_config from official.modeling.hyperparams import base_config
from official.modeling.hyperparams import config_definitions as cfg from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import encoders from official.nlp.configs import encoders
...@@ -62,7 +63,7 @@ class SentencePredictionConfig(cfg.TaskConfig): ...@@ -62,7 +63,7 @@ class SentencePredictionConfig(cfg.TaskConfig):
validation_data: cfg.DataConfig = cfg.DataConfig() validation_data: cfg.DataConfig = cfg.DataConfig()
@base_task.register_task_cls(SentencePredictionConfig) @task_factory.register_task_cls(SentencePredictionConfig)
class SentencePredictionTask(base_task.Task): class SentencePredictionTask(base_task.Task):
"""Task object for sentence_prediction.""" """Task object for sentence_prediction."""
......
...@@ -25,6 +25,7 @@ import tensorflow as tf ...@@ -25,6 +25,7 @@ import tensorflow as tf
import tensorflow_hub as hub import tensorflow_hub as hub
from official.core import base_task from official.core import base_task
from official.core import task_factory
from official.modeling.hyperparams import base_config from official.modeling.hyperparams import base_config
from official.modeling.hyperparams import config_definitions as cfg from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import encoders from official.nlp.configs import encoders
...@@ -80,7 +81,7 @@ def _masked_labels_and_weights(y_true): ...@@ -80,7 +81,7 @@ def _masked_labels_and_weights(y_true):
return masked_y_true, tf.cast(mask, tf.float32) return masked_y_true, tf.cast(mask, tf.float32)
@base_task.register_task_cls(TaggingConfig) @task_factory.register_task_cls(TaggingConfig)
class TaggingTask(base_task.Task): class TaggingTask(base_task.Task):
"""Task object for tagging (e.g., NER or POS).""" """Task object for tagging (e.g., NER or POS)."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment