Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
5eb294f8
Commit
5eb294f8
authored
Jul 30, 2020
by
Abdullah Rashwan
Committed by
A. Unique TensorFlower
Jul 30, 2020
Browse files
Internal change
PiperOrigin-RevId: 324140487
parent
a62c2bfc
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
115 additions
and
66 deletions
+115
-66
official/core/base_task.py
official/core/base_task.py
+0
-47
official/core/exp_factory.py
official/core/exp_factory.py
+37
-0
official/core/task_factory.py
official/core/task_factory.py
+68
-0
official/modeling/hyperparams/config_definitions.py
official/modeling/hyperparams/config_definitions.py
+0
-14
official/nlp/tasks/electra_task.py
official/nlp/tasks/electra_task.py
+2
-1
official/nlp/tasks/masked_lm.py
official/nlp/tasks/masked_lm.py
+2
-1
official/nlp/tasks/question_answering.py
official/nlp/tasks/question_answering.py
+2
-1
official/nlp/tasks/sentence_prediction.py
official/nlp/tasks/sentence_prediction.py
+2
-1
official/nlp/tasks/tagging.py
official/nlp/tasks/tagging.py
+2
-1
No files found.
official/core/base_task.py
View file @
5eb294f8
...
@@ -23,7 +23,6 @@ import six
...
@@ -23,7 +23,6 @@ import six
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.modeling.hyperparams
import
config_definitions
as
cfg
from
official.modeling.hyperparams
import
config_definitions
as
cfg
from
official.utils
import
registry
@
six
.
add_metaclass
(
abc
.
ABCMeta
)
@
six
.
add_metaclass
(
abc
.
ABCMeta
)
...
@@ -295,49 +294,3 @@ class Task(tf.Module):
...
@@ -295,49 +294,3 @@ class Task(tf.Module):
"""Optional reduce of aggregated logs over validation steps."""
"""Optional reduce of aggregated logs over validation steps."""
return
{}
return
{}
_REGISTERED_TASK_CLS
=
{}
# TODO(b/158268740): Move these outside the base class file.
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def
register_task_cls
(
task_config_cls
):
"""Decorates a factory of Tasks for lookup by a subclass of TaskConfig.
This decorator supports registration of tasks as follows:
```
@dataclasses.dataclass
class MyTaskConfig(TaskConfig):
# Add fields here.
pass
@register_task_cls(MyTaskConfig)
class MyTask(Task):
# Inherits def __init__(self, task_config).
pass
my_task_config = MyTaskConfig()
my_task = get_task(my_task_config) # Returns MyTask(my_task_config).
```
Besisdes a class itself, other callables that create a Task from a TaskConfig
can be decorated by the result of this function, as long as there is at most
one registration for each config class.
Args:
task_config_cls: a subclass of TaskConfig (*not* an instance of TaskConfig).
Each task_config_cls can only be used for a single registration.
Returns:
A callable for use as class decorator that registers the decorated class
for creation from an instance of task_config_cls.
"""
return
registry
.
register
(
_REGISTERED_TASK_CLS
,
task_config_cls
)
# The user-visible get_task() is defined after classes have been registered.
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def
get_task_cls
(
task_config_cls
):
task_cls
=
registry
.
lookup
(
_REGISTERED_TASK_CLS
,
task_config_cls
)
return
task_cls
official/core/exp_factory.py
0 → 100644
View file @
5eb294f8
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experiment factory methods."""
from
official.modeling.hyperparams
import
config_definitions
as
cfg
from
official.utils
import
registry
_REGISTERED_CONFIGS
=
{}
def
register_config_factory
(
name
):
"""Register ExperimentConfig factory method."""
return
registry
.
register
(
_REGISTERED_CONFIGS
,
name
)
def
get_exp_config_creater
(
exp_name
:
str
):
"""Looks up ExperimentConfig factory methods."""
exp_creater
=
registry
.
lookup
(
_REGISTERED_CONFIGS
,
exp_name
)
return
exp_creater
def
get_exp_config
(
exp_name
:
str
)
->
cfg
.
ExperimentConfig
:
return
get_exp_config_creater
(
exp_name
)()
official/core/task_factory.py
0 → 100644
View file @
5eb294f8
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A global factory to register and access all registered tasks."""
from
official.utils
import
registry
_REGISTERED_TASK_CLS
=
{}
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def
register_task_cls
(
task_config_cls
):
"""Decorates a factory of Tasks for lookup by a subclass of TaskConfig.
This decorator supports registration of tasks as follows:
```
@dataclasses.dataclass
class MyTaskConfig(TaskConfig):
# Add fields here.
pass
@register_task_cls(MyTaskConfig)
class MyTask(Task):
# Inherits def __init__(self, task_config).
pass
my_task_config = MyTaskConfig()
my_task = get_task(my_task_config) # Returns MyTask(my_task_config).
```
Besisdes a class itself, other callables that create a Task from a TaskConfig
can be decorated by the result of this function, as long as there is at most
one registration for each config class.
Args:
task_config_cls: a subclass of TaskConfig (*not* an instance of TaskConfig).
Each task_config_cls can only be used for a single registration.
Returns:
A callable for use as class decorator that registers the decorated class
for creation from an instance of task_config_cls.
"""
return
registry
.
register
(
_REGISTERED_TASK_CLS
,
task_config_cls
)
def
get_task
(
task_config
,
**
kwargs
):
"""Creates a Task (of suitable subclass type) from task_config."""
return
get_task_cls
(
task_config
.
__class__
)(
task_config
,
**
kwargs
)
# The user-visible get_task() is defined after classes have been registered.
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def
get_task_cls
(
task_config_cls
):
task_cls
=
registry
.
lookup
(
_REGISTERED_TASK_CLS
,
task_config_cls
)
return
task_cls
official/modeling/hyperparams/config_definitions.py
View file @
5eb294f8
...
@@ -21,7 +21,6 @@ import dataclasses
...
@@ -21,7 +21,6 @@ import dataclasses
from
official.modeling.hyperparams
import
base_config
from
official.modeling.hyperparams
import
base_config
from
official.modeling.optimization.configs
import
optimization_config
from
official.modeling.optimization.configs
import
optimization_config
from
official.utils
import
registry
OptimizationConfig
=
optimization_config
.
OptimizationConfig
OptimizationConfig
=
optimization_config
.
OptimizationConfig
...
@@ -219,16 +218,3 @@ class ExperimentConfig(base_config.Config):
...
@@ -219,16 +218,3 @@ class ExperimentConfig(base_config.Config):
trainer
:
TrainerConfig
=
TrainerConfig
()
trainer
:
TrainerConfig
=
TrainerConfig
()
runtime
:
RuntimeConfig
=
RuntimeConfig
()
runtime
:
RuntimeConfig
=
RuntimeConfig
()
_REGISTERED_CONFIGS
=
{}
def
register_config_factory
(
name
):
"""Register ExperimentConfig factory method."""
return
registry
.
register
(
_REGISTERED_CONFIGS
,
name
)
def
get_exp_config_creater
(
exp_name
:
str
):
"""Looks up ExperimentConfig factory methods."""
exp_creater
=
registry
.
lookup
(
_REGISTERED_CONFIGS
,
exp_name
)
return
exp_creater
official/nlp/tasks/electra_task.py
View file @
5eb294f8
...
@@ -18,6 +18,7 @@ import dataclasses
...
@@ -18,6 +18,7 @@ import dataclasses
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.core
import
base_task
from
official.core
import
base_task
from
official.core
import
task_factory
from
official.modeling.hyperparams
import
config_definitions
as
cfg
from
official.modeling.hyperparams
import
config_definitions
as
cfg
from
official.nlp.configs
import
bert
from
official.nlp.configs
import
bert
from
official.nlp.configs
import
electra
from
official.nlp.configs
import
electra
...
@@ -39,7 +40,7 @@ class ELECTRAPretrainConfig(cfg.TaskConfig):
...
@@ -39,7 +40,7 @@ class ELECTRAPretrainConfig(cfg.TaskConfig):
validation_data
:
cfg
.
DataConfig
=
cfg
.
DataConfig
()
validation_data
:
cfg
.
DataConfig
=
cfg
.
DataConfig
()
@
b
as
e_task
.
register_task_cls
(
ELECTRAPretrainConfig
)
@
t
as
k_factory
.
register_task_cls
(
ELECTRAPretrainConfig
)
class
ELECTRAPretrainTask
(
base_task
.
Task
):
class
ELECTRAPretrainTask
(
base_task
.
Task
):
"""ELECTRA Pretrain Task (Masked LM + Replaced Token Detection)."""
"""ELECTRA Pretrain Task (Masked LM + Replaced Token Detection)."""
...
...
official/nlp/tasks/masked_lm.py
View file @
5eb294f8
...
@@ -18,6 +18,7 @@ import dataclasses
...
@@ -18,6 +18,7 @@ import dataclasses
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.core
import
base_task
from
official.core
import
base_task
from
official.core
import
task_factory
from
official.modeling.hyperparams
import
config_definitions
as
cfg
from
official.modeling.hyperparams
import
config_definitions
as
cfg
from
official.nlp.configs
import
bert
from
official.nlp.configs
import
bert
from
official.nlp.data
import
data_loader_factory
from
official.nlp.data
import
data_loader_factory
...
@@ -34,7 +35,7 @@ class MaskedLMConfig(cfg.TaskConfig):
...
@@ -34,7 +35,7 @@ class MaskedLMConfig(cfg.TaskConfig):
validation_data
:
cfg
.
DataConfig
=
cfg
.
DataConfig
()
validation_data
:
cfg
.
DataConfig
=
cfg
.
DataConfig
()
@
b
as
e_task
.
register_task_cls
(
MaskedLMConfig
)
@
t
as
k_factory
.
register_task_cls
(
MaskedLMConfig
)
class
MaskedLMTask
(
base_task
.
Task
):
class
MaskedLMTask
(
base_task
.
Task
):
"""Mock task object for testing."""
"""Mock task object for testing."""
...
...
official/nlp/tasks/question_answering.py
View file @
5eb294f8
...
@@ -23,6 +23,7 @@ import tensorflow as tf
...
@@ -23,6 +23,7 @@ import tensorflow as tf
import
tensorflow_hub
as
hub
import
tensorflow_hub
as
hub
from
official.core
import
base_task
from
official.core
import
base_task
from
official.core
import
task_factory
from
official.modeling.hyperparams
import
base_config
from
official.modeling.hyperparams
import
base_config
from
official.modeling.hyperparams
import
config_definitions
as
cfg
from
official.modeling.hyperparams
import
config_definitions
as
cfg
from
official.nlp.bert
import
squad_evaluate_v1_1
from
official.nlp.bert
import
squad_evaluate_v1_1
...
@@ -57,7 +58,7 @@ class QuestionAnsweringConfig(cfg.TaskConfig):
...
@@ -57,7 +58,7 @@ class QuestionAnsweringConfig(cfg.TaskConfig):
validation_data
:
cfg
.
DataConfig
=
cfg
.
DataConfig
()
validation_data
:
cfg
.
DataConfig
=
cfg
.
DataConfig
()
@
b
as
e_task
.
register_task_cls
(
QuestionAnsweringConfig
)
@
t
as
k_factory
.
register_task_cls
(
QuestionAnsweringConfig
)
class
QuestionAnsweringTask
(
base_task
.
Task
):
class
QuestionAnsweringTask
(
base_task
.
Task
):
"""Task object for question answering."""
"""Task object for question answering."""
...
...
official/nlp/tasks/sentence_prediction.py
View file @
5eb294f8
...
@@ -26,6 +26,7 @@ import tensorflow as tf
...
@@ -26,6 +26,7 @@ import tensorflow as tf
import
tensorflow_hub
as
hub
import
tensorflow_hub
as
hub
from
official.core
import
base_task
from
official.core
import
base_task
from
official.core
import
task_factory
from
official.modeling.hyperparams
import
base_config
from
official.modeling.hyperparams
import
base_config
from
official.modeling.hyperparams
import
config_definitions
as
cfg
from
official.modeling.hyperparams
import
config_definitions
as
cfg
from
official.nlp.configs
import
encoders
from
official.nlp.configs
import
encoders
...
@@ -62,7 +63,7 @@ class SentencePredictionConfig(cfg.TaskConfig):
...
@@ -62,7 +63,7 @@ class SentencePredictionConfig(cfg.TaskConfig):
validation_data
:
cfg
.
DataConfig
=
cfg
.
DataConfig
()
validation_data
:
cfg
.
DataConfig
=
cfg
.
DataConfig
()
@
b
as
e_task
.
register_task_cls
(
SentencePredictionConfig
)
@
t
as
k_factory
.
register_task_cls
(
SentencePredictionConfig
)
class
SentencePredictionTask
(
base_task
.
Task
):
class
SentencePredictionTask
(
base_task
.
Task
):
"""Task object for sentence_prediction."""
"""Task object for sentence_prediction."""
...
...
official/nlp/tasks/tagging.py
View file @
5eb294f8
...
@@ -25,6 +25,7 @@ import tensorflow as tf
...
@@ -25,6 +25,7 @@ import tensorflow as tf
import
tensorflow_hub
as
hub
import
tensorflow_hub
as
hub
from
official.core
import
base_task
from
official.core
import
base_task
from
official.core
import
task_factory
from
official.modeling.hyperparams
import
base_config
from
official.modeling.hyperparams
import
base_config
from
official.modeling.hyperparams
import
config_definitions
as
cfg
from
official.modeling.hyperparams
import
config_definitions
as
cfg
from
official.nlp.configs
import
encoders
from
official.nlp.configs
import
encoders
...
@@ -80,7 +81,7 @@ def _masked_labels_and_weights(y_true):
...
@@ -80,7 +81,7 @@ def _masked_labels_and_weights(y_true):
return
masked_y_true
,
tf
.
cast
(
mask
,
tf
.
float32
)
return
masked_y_true
,
tf
.
cast
(
mask
,
tf
.
float32
)
@
b
as
e_task
.
register_task_cls
(
TaggingConfig
)
@
t
as
k_factory
.
register_task_cls
(
TaggingConfig
)
class
TaggingTask
(
base_task
.
Task
):
class
TaggingTask
(
base_task
.
Task
):
"""Task object for tagging (e.g., NER or POS)."""
"""Task object for tagging (e.g., NER or POS)."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment