Unverified Commit 09d9656f authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by GitHub
Browse files

Merge branch 'panoptic-segmentation' into panoptic-deeplab-modeling

parents ac671306 49a5706c
......@@ -45,13 +45,17 @@ class TeamsPretrainerConfig(base_config.Config):
num_discriminator_task_agnostic_layers: int = 11
generator: encoders.BertEncoderConfig = encoders.BertEncoderConfig()
discriminator: encoders.BertEncoderConfig = encoders.BertEncoderConfig()
# Used for compatibility with continuous finetuning where common BERT config
# is used.
encoder: encoders.EncoderConfig = encoders.EncoderConfig()
class TeamsEncoderConfig(encoders.BertEncoderConfig):
pass
@gin.configurable
def get_encoder(bert_config, embedding_network=None, hidden_layers=None):
@base_config.bind(TeamsEncoderConfig)
def get_encoder(bert_config: TeamsEncoderConfig,
embedding_network=None,
hidden_layers=None):
"""Gets a 'EncoderScaffold' object.
Args:
......@@ -98,4 +102,4 @@ def get_encoder(bert_config, embedding_network=None, hidden_layers=None):
dict_outputs=True)
# Relies on gin configuration to define the Transformer encoder arguments.
return networks.encoder_scaffold.EncoderScaffold(**kwargs)
return networks.EncoderScaffold(**kwargs)
......@@ -16,12 +16,18 @@
# pylint: disable=g-doc-return-or-yield,line-too-long
"""TEAMS experiments."""
import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import optimization
from official.nlp.configs import encoders
from official.nlp.data import pretrain_dataloader
from official.nlp.projects.teams import teams_task
from official.nlp.data import question_answering_dataloader
from official.nlp.data import sentence_prediction_dataloader
from official.nlp.tasks import question_answering
from official.nlp.tasks import sentence_prediction
from official.projects.teams import teams
from official.projects.teams import teams_task
AdamWeightDecay = optimization.AdamWeightDecayConfig
PolynomialLr = optimization.PolynomialLrConfig
......@@ -62,3 +68,42 @@ def teams_pretrain() -> cfg.ExperimentConfig:
"task.validation_data.is_training != None"
])
return config
@exp_factory.register_config_factory("teams/sentence_prediction")
def teams_sentence_prediction() -> cfg.ExperimentConfig:
r"""Teams GLUE."""
config = cfg.ExperimentConfig(
task=sentence_prediction.SentencePredictionConfig(
model=sentence_prediction.ModelConfig(
encoder=encoders.EncoderConfig(
type="any", any=teams.TeamsEncoderConfig(num_layers=1))),
train_data=sentence_prediction_dataloader
.SentencePredictionDataConfig(),
validation_data=sentence_prediction_dataloader
.SentencePredictionDataConfig(
is_training=False, drop_remainder=False)),
trainer=cfg.TrainerConfig(optimizer_config=TeamsOptimizationConfig()),
restrictions=[
"task.train_data.is_training != None",
"task.validation_data.is_training != None"
])
return config
@exp_factory.register_config_factory("teams/squad")
def teams_squad() -> cfg.ExperimentConfig:
"""Teams Squad V1/V2."""
config = cfg.ExperimentConfig(
task=question_answering.QuestionAnsweringConfig(
model=question_answering.ModelConfig(
encoder=encoders.EncoderConfig(
type="any", any=teams.TeamsEncoderConfig(num_layers=1))),
train_data=question_answering_dataloader.QADataConfig(),
validation_data=question_answering_dataloader.QADataConfig()),
trainer=cfg.TrainerConfig(optimizer_config=TeamsOptimizationConfig()),
restrictions=[
"task.train_data.is_training != None",
"task.validation_data.is_training != None"
])
return config
......@@ -20,7 +20,7 @@ from tensorflow.python.keras import keras_parameterized # pylint: disable=g-dir
from official.modeling import activations
from official.nlp.modeling.networks import encoder_scaffold
from official.nlp.modeling.networks import packed_sequence_embedding
from official.nlp.projects.teams import teams_pretrainer
from official.projects.teams import teams_pretrainer
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
......
......@@ -23,8 +23,8 @@ from official.core import task_factory
from official.modeling import tf_utils
from official.nlp.data import pretrain_dataloader
from official.nlp.modeling import layers
from official.nlp.projects.teams import teams
from official.nlp.projects.teams import teams_pretrainer
from official.projects.teams import teams
from official.projects.teams import teams_pretrainer
@dataclasses.dataclass
......
......@@ -19,8 +19,8 @@ import tensorflow as tf
from official.nlp.configs import encoders
from official.nlp.data import pretrain_dataloader
from official.nlp.projects.teams import teams
from official.nlp.projects.teams import teams_task
from official.projects.teams import teams
from official.projects.teams import teams_task
class TeamsPretrainTaskTest(tf.test.TestCase, parameterized.TestCase):
......
......@@ -27,8 +27,8 @@ from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import encoders
from official.nlp.modeling import models
from official.nlp.projects.example import classification_data_loader
from official.nlp.tasks import utils
from official.projects.text_classification_example import classification_data_loader
@dataclasses.dataclass
......
......@@ -18,8 +18,8 @@ import tensorflow as tf
from official.core import config_definitions as cfg
from official.nlp.configs import encoders
from official.nlp.projects.example import classification_data_loader
from official.nlp.projects.example import classification_example
from official.projects.text_classification_example import classification_data_loader
from official.projects.text_classification_example import classification_example
class ClassificationExampleTest(tf.test.TestCase):
......
......@@ -23,7 +23,7 @@ from official.common import flags as tfm_flags
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
from official.nlp.projects.example import classification_example
from official.projects.text_classification_example import classification_example
FLAGS = flags.FLAGS
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
......@@ -23,7 +23,7 @@ import six
import tensorflow as tf
import tensorflow_datasets.public_api as tfds
from official.nlp.projects.triviaqa import preprocess
from official.projects.triviaqa import preprocess
_CITATION = """
@article{2017arXivtriviaqa,
......
......@@ -21,7 +21,7 @@ from absl import logging
import apache_beam as beam
import tensorflow_datasets as tfds
from official.nlp.projects.triviaqa import dataset # pylint: disable=unused-import
from official.projects.triviaqa import dataset # pylint: disable=unused-import
flags.DEFINE_integer('sequence_length', 4096, 'Max number of tokens.')
......
......@@ -20,7 +20,7 @@ from absl import flags
from absl import logging
import tensorflow as tf
from official.nlp.projects.triviaqa import evaluation
from official.projects.triviaqa import evaluation
flags.DEFINE_string('gold_path', None,
'Path to golden validation, i.e. wikipedia-dev.json.')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment