Commit 728b3dbf authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 413431736
parent 9b559ad1
@@ -232,8 +232,9 @@ class EncoderConfig(hyperparams.OneOfConfig):
kernel: KernelEncoderConfig = KernelEncoderConfig()
mobilebert: MobileBertEncoderConfig = MobileBertEncoderConfig()
reuse: ReuseEncoderConfig = ReuseEncoderConfig()
teams: BertEncoderConfig = BertEncoderConfig()
xlnet: XLNetEncoderConfig = XLNetEncoderConfig()
# If `any` is used, the encoder building relies on any.BUILDER.
any: hyperparams.Config = hyperparams.Config()
@gin.configurable
@@ -290,6 +291,16 @@ def build_encoder(config: EncoderConfig,
dict_outputs=True)
return encoder_cls(**kwargs)
if encoder_type == "any":
encoder = encoder_cfg.BUILDER(encoder_cfg)
if not isinstance(encoder,
(tf.Module, tf.keras.Model, tf.keras.layers.Layer)):
raise ValueError("The BUILDER returns an unexpected instance. The "
"`build_encoder` should returns a tf.Module, "
"tf.keras.Model or tf.keras.layers.Layer. However, "
f"we get {encoder.__class__}")
return encoder
if encoder_type == "mobilebert":
return networks.MobileBERTEncoder(
word_vocab_size=encoder_cfg.word_vocab_size,
@@ -465,40 +476,6 @@ def build_encoder(config: EncoderConfig,
initializer=tf.keras.initializers.RandomNormal(
stddev=encoder_cfg.initializer_range))
if encoder_type == "teams":
embedding_cfg = dict(
vocab_size=encoder_cfg.vocab_size,
type_vocab_size=encoder_cfg.type_vocab_size,
hidden_size=encoder_cfg.hidden_size,
embedding_width=encoder_cfg.embedding_size,
max_seq_length=encoder_cfg.max_position_embeddings,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
dropout_rate=encoder_cfg.dropout_rate,
)
embedding_network = networks.PackedSequenceEmbedding(**embedding_cfg)
hidden_cfg = dict(
num_attention_heads=encoder_cfg.num_attention_heads,
intermediate_size=encoder_cfg.intermediate_size,
intermediate_activation=tf_utils.get_activation(
encoder_cfg.hidden_activation),
dropout_rate=encoder_cfg.dropout_rate,
attention_dropout_rate=encoder_cfg.attention_dropout_rate,
kernel_initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
)
kwargs = dict(
embedding_cfg=embedding_cfg,
embedding_cls=embedding_network,
hidden_cfg=hidden_cfg,
num_hidden_instances=encoder_cfg.num_layers,
pooled_output_dim=encoder_cfg.hidden_size,
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
return_all_layer_outputs=encoder_cfg.return_all_encoder_outputs,
dict_outputs=True)
return networks.EncoderScaffold(**kwargs)
if encoder_type == "reuse":
embedding_cfg = dict(
vocab_size=encoder_cfg.vocab_size,
......
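The new `any` branch lets a project plug in a custom encoder without extending the `EncoderConfig` oneof for every architecture: the config object carries its own builder via `BUILDER`, and `build_encoder` only validates the result. A minimal sketch of the pattern, using the `base_config.bind` decorator that this change applies to TEAMS (`MyEncoderConfig` and `build_my_encoder` are hypothetical names for illustration):

```python
import tensorflow as tf

from official.modeling.hyperparams import base_config
from official.nlp.configs import encoders


class MyEncoderConfig(encoders.BertEncoderConfig):
  """Hypothetical project-specific encoder config."""


@base_config.bind(MyEncoderConfig)
def build_my_encoder(config: MyEncoderConfig):
  # Must return a tf.Module, tf.keras.Model, or tf.keras.layers.Layer;
  # otherwise the "any" branch above raises a ValueError.
  return tf.keras.layers.Dense(config.hidden_size)


# Select the custom encoder through the oneof's new `any` field.
config = encoders.EncoderConfig(type="any", any=MyEncoderConfig())
encoder = encoders.build_encoder(config)
```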
@@ -19,6 +19,8 @@ import tensorflow as tf
from official.modeling import hyperparams
from official.nlp.configs import encoders
from official.nlp.modeling import networks
from official.nlp.projects.teams import teams
class EncodersTest(tf.test.TestCase):
@@ -37,6 +39,14 @@ class EncodersTest(tf.test.TestCase):
    status = tf.train.Checkpoint(encoder=restored_encoder).restore(ckpt_path)
status.assert_consumed()
def test_build_teams(self):
config = encoders.EncoderConfig(
type="any", any=teams.TeamsEncoderConfig(num_layers=1))
encoder = encoders.build_encoder(config)
self.assertIsInstance(encoder, networks.EncoderScaffold)
self.assertIsInstance(encoder.embedding_network,
networks.PackedSequenceEmbedding)
if __name__ == "__main__":
tf.test.main()
task:
model:
encoder:
teams:
any: # Teams encoder.
attention_dropout_rate: 0.1
dropout_rate: 0.1
embedding_size: 768
@@ -14,4 +14,4 @@ task:
num_layers: 12
type_vocab_size: 2
vocab_size: 30522
type: teams
type: any
task:
model:
encoder:
teams:
any: # Teams encoder.
attention_dropout_rate: 0.1
dropout_rate: 0.1
embedding_size: 128
@@ -14,4 +14,4 @@ task:
num_layers: 12
type_vocab_size: 2
vocab_size: 30522
type: teams
type: any
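In the YAML experiment overrides, the oneof key accordingly moves from `teams:` to `any:`, with `type: any` selecting the new branch. As a sketch, the config immediately above (embedding_size 128) is equivalent to constructing the encoder config in Python, mirroring the new unit test:

```python
from official.nlp.configs import encoders
from official.nlp.projects.teams import teams

encoder_cfg = encoders.EncoderConfig(
    type="any",
    any=teams.TeamsEncoderConfig(
        attention_dropout_rate=0.1,
        dropout_rate=0.1,
        embedding_size=128,
        num_layers=12,
        type_vocab_size=2,
        vocab_size=30522))
```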
@@ -45,13 +45,17 @@ class TeamsPretrainerConfig(base_config.Config):
num_discriminator_task_agnostic_layers: int = 11
generator: encoders.BertEncoderConfig = encoders.BertEncoderConfig()
discriminator: encoders.BertEncoderConfig = encoders.BertEncoderConfig()
  # Used for compatibility with continuous fine-tuning, where the common BERT
  # config is used.
encoder: encoders.EncoderConfig = encoders.EncoderConfig()
class TeamsEncoderConfig(encoders.BertEncoderConfig):
  """TEAMS encoder config, used to bind a dedicated encoder builder."""
@gin.configurable
def get_encoder(bert_config, embedding_network=None, hidden_layers=None):
@base_config.bind(TeamsEncoderConfig)
def get_encoder(bert_config: TeamsEncoderConfig,
embedding_network=None,
hidden_layers=None):
"""Gets a 'EncoderScaffold' object.
Args:
@@ -98,4 +102,4 @@ def get_encoder(bert_config, embedding_network=None, hidden_layers=None):
dict_outputs=True)
# Relies on gin configuration to define the Transformer encoder arguments.
return networks.encoder_scaffold.EncoderScaffold(**kwargs)
return networks.EncoderScaffold(**kwargs)
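With `@base_config.bind(TeamsEncoderConfig)`, `get_encoder` becomes the `BUILDER` that the `any` branch of `build_encoder` invokes, replacing the previous gin-only wiring. The builder can also be called directly; a sketch based on the new test:

```python
from official.nlp.projects.teams import teams

# Returns an EncoderScaffold whose embedding_network is a
# PackedSequenceEmbedding, as asserted in encoders_test.py.
encoder = teams.get_encoder(teams.TeamsEncoderConfig(num_layers=1))
```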
@@ -16,12 +16,18 @@
# pylint: disable=g-doc-return-or-yield,line-too-long
"""TEAMS experiments."""
import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import optimization
from official.nlp.configs import encoders
from official.nlp.data import pretrain_dataloader
from official.nlp.data import question_answering_dataloader
from official.nlp.data import sentence_prediction_dataloader
from official.nlp.projects.teams import teams
from official.nlp.projects.teams import teams_task
from official.nlp.tasks import question_answering
from official.nlp.tasks import sentence_prediction
AdamWeightDecay = optimization.AdamWeightDecayConfig
PolynomialLr = optimization.PolynomialLrConfig
@@ -62,3 +68,42 @@ def teams_pretrain() -> cfg.ExperimentConfig:
"task.validation_data.is_training != None"
])
return config
@exp_factory.register_config_factory("teams/sentence_prediction")
def teams_sentence_prediction() -> cfg.ExperimentConfig:
r"""Teams GLUE."""
config = cfg.ExperimentConfig(
task=sentence_prediction.SentencePredictionConfig(
model=sentence_prediction.ModelConfig(
encoder=encoders.EncoderConfig(
type="any", any=teams.TeamsEncoderConfig(num_layers=1))),
train_data=sentence_prediction_dataloader
.SentencePredictionDataConfig(),
validation_data=sentence_prediction_dataloader
.SentencePredictionDataConfig(
is_training=False, drop_remainder=False)),
trainer=cfg.TrainerConfig(optimizer_config=TeamsOptimizationConfig()),
restrictions=[
"task.train_data.is_training != None",
"task.validation_data.is_training != None"
])
return config
@exp_factory.register_config_factory("teams/squad")
def teams_squad() -> cfg.ExperimentConfig:
"""Teams Squad V1/V2."""
config = cfg.ExperimentConfig(
task=question_answering.QuestionAnsweringConfig(
model=question_answering.ModelConfig(
encoder=encoders.EncoderConfig(
type="any", any=teams.TeamsEncoderConfig(num_layers=1))),
train_data=question_answering_dataloader.QADataConfig(),
validation_data=question_answering_dataloader.QADataConfig()),
trainer=cfg.TrainerConfig(optimizer_config=TeamsOptimizationConfig()),
restrictions=[
"task.train_data.is_training != None",
"task.validation_data.is_training != None"
])
return config
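Once registered, the two new experiments can be looked up by name through the standard experiment factory, e.g.:

```python
from official.core import exp_factory

glue_config = exp_factory.get_exp_config("teams/sentence_prediction")
squad_config = exp_factory.get_exp_config("teams/squad")
```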