Commit 0268d1d3 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 413431736
parent 6af14369
@@ -232,8 +232,9 @@ class EncoderConfig(hyperparams.OneOfConfig):
   kernel: KernelEncoderConfig = KernelEncoderConfig()
   mobilebert: MobileBertEncoderConfig = MobileBertEncoderConfig()
   reuse: ReuseEncoderConfig = ReuseEncoderConfig()
-  teams: BertEncoderConfig = BertEncoderConfig()
   xlnet: XLNetEncoderConfig = XLNetEncoderConfig()
+  # If `any` is used, the encoder building relies on any.BUILDER.
+  any: hyperparams.Config = hyperparams.Config()


 @gin.configurable
@@ -290,6 +291,16 @@ def build_encoder(config: EncoderConfig,
         dict_outputs=True)
     return encoder_cls(**kwargs)
+  if encoder_type == "any":
+    encoder = encoder_cfg.BUILDER(encoder_cfg)
+    if not isinstance(encoder,
+                      (tf.Module, tf.keras.Model, tf.keras.layers.Layer)):
+      raise ValueError("The BUILDER returned an unexpected instance. "
+                       "`build_encoder` should return a tf.Module, "
+                       "tf.keras.Model, or tf.keras.layers.Layer; "
+                       f"got {encoder.__class__} instead.")
+    return encoder
   if encoder_type == "mobilebert":
     return networks.MobileBERTEncoder(
         word_vocab_size=encoder_cfg.word_vocab_size,
@@ -465,40 +476,6 @@ def build_encoder(config: EncoderConfig,
         initializer=tf.keras.initializers.RandomNormal(
             stddev=encoder_cfg.initializer_range))
-  if encoder_type == "teams":
-    embedding_cfg = dict(
-        vocab_size=encoder_cfg.vocab_size,
-        type_vocab_size=encoder_cfg.type_vocab_size,
-        hidden_size=encoder_cfg.hidden_size,
-        embedding_width=encoder_cfg.embedding_size,
-        max_seq_length=encoder_cfg.max_position_embeddings,
-        initializer=tf.keras.initializers.TruncatedNormal(
-            stddev=encoder_cfg.initializer_range),
-        dropout_rate=encoder_cfg.dropout_rate,
-    )
-    embedding_network = networks.PackedSequenceEmbedding(**embedding_cfg)
-    hidden_cfg = dict(
-        num_attention_heads=encoder_cfg.num_attention_heads,
-        intermediate_size=encoder_cfg.intermediate_size,
-        intermediate_activation=tf_utils.get_activation(
-            encoder_cfg.hidden_activation),
-        dropout_rate=encoder_cfg.dropout_rate,
-        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
-        kernel_initializer=tf.keras.initializers.TruncatedNormal(
-            stddev=encoder_cfg.initializer_range),
-    )
-    kwargs = dict(
-        embedding_cfg=embedding_cfg,
-        embedding_cls=embedding_network,
-        hidden_cfg=hidden_cfg,
-        num_hidden_instances=encoder_cfg.num_layers,
-        pooled_output_dim=encoder_cfg.hidden_size,
-        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
-            stddev=encoder_cfg.initializer_range),
-        return_all_layer_outputs=encoder_cfg.return_all_encoder_outputs,
-        dict_outputs=True)
-    return networks.EncoderScaffold(**kwargs)
   if encoder_type == "reuse":
     embedding_cfg = dict(
         vocab_size=encoder_cfg.vocab_size,
...
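Note on the mechanism introduced above: `EncoderConfig` now carries a generic `any` field, and `build_encoder` defers to `encoder_cfg.BUILDER` when `type == "any"`. A minimal sketch of how a project can plug in a custom encoder this way, assuming `base_config.bind` attaches the decorated builder to the config class (as `teams.py` does below); `MyEncoderConfig` and `build_my_encoder` are hypothetical names:

import tensorflow as tf

from official.modeling.hyperparams import base_config
from official.nlp.configs import encoders


class MyEncoderConfig(encoders.BertEncoderConfig):
  """Hypothetical project-specific encoder config."""
  pass


@base_config.bind(MyEncoderConfig)  # Registers MyEncoderConfig.BUILDER.
def build_my_encoder(config: MyEncoderConfig) -> tf.keras.Model:
  # Placeholder body: anything that is a tf.Module, tf.keras.Model, or
  # tf.keras.layers.Layer passes the isinstance() check in build_encoder.
  return tf.keras.Sequential(
      [tf.keras.layers.Dense(config.hidden_size)
       for _ in range(config.num_layers)])


# build_encoder dispatches on type="any" and calls the bound BUILDER.
encoder = encoders.build_encoder(
    encoders.EncoderConfig(type="any", any=MyEncoderConfig(num_layers=2)))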
@@ -19,6 +19,8 @@ import tensorflow as tf
 from official.modeling import hyperparams
 from official.nlp.configs import encoders
+from official.nlp.modeling import networks
+from official.nlp.projects.teams import teams


 class EncodersTest(tf.test.TestCase):
@@ -37,6 +39,14 @@ class EncodersTest(tf.test.TestCase):
     status = tf.train.Checkpoint(encoder=retored_encoder).restore(ckpt_path)
     status.assert_consumed()

+  def test_build_teams(self):
+    config = encoders.EncoderConfig(
+        type="any", any=teams.TeamsEncoderConfig(num_layers=1))
+    encoder = encoders.build_encoder(config)
+    self.assertIsInstance(encoder, networks.EncoderScaffold)
+    self.assertIsInstance(encoder.embedding_network,
+                          networks.PackedSequenceEmbedding)


 if __name__ == "__main__":
   tf.test.main()
@@ -1,7 +1,7 @@
 task:
   model:
     encoder:
-      teams:
+      any:  # Teams encoder.
         attention_dropout_rate: 0.1
         dropout_rate: 0.1
         embedding_size: 768
@@ -14,4 +14,4 @@ task:
         num_layers: 12
         type_vocab_size: 2
         vocab_size: 30522
-      type: teams
+      type: any
@@ -1,7 +1,7 @@
 task:
   model:
     encoder:
-      teams:
+      any:  # Teams encoder.
         attention_dropout_rate: 0.1
         dropout_rate: 0.1
         embedding_size: 128
@@ -14,4 +14,4 @@ task:
         num_layers: 12
         type_vocab_size: 2
         vocab_size: 30522
-      type: teams
+      type: any
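The two YAML diffs above are the config-file view of the same change: `type: any` selects the `any` branch of the OneOfConfig, and the nested block populates a `TeamsEncoderConfig`. Built in Python (mirroring the updated test), the small variant is roughly equivalent to:

from official.nlp.configs import encoders
from official.nlp.projects.teams import teams

# Same fields the YAML sets; TeamsEncoderConfig inherits them from
# BertEncoderConfig.
encoder_config = encoders.EncoderConfig(
    type="any",
    any=teams.TeamsEncoderConfig(
        attention_dropout_rate=0.1,
        dropout_rate=0.1,
        embedding_size=128,
        num_layers=12,
        type_vocab_size=2,
        vocab_size=30522))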
@@ -45,13 +45,17 @@ class TeamsPretrainerConfig(base_config.Config):
   num_discriminator_task_agnostic_layers: int = 11
   generator: encoders.BertEncoderConfig = encoders.BertEncoderConfig()
   discriminator: encoders.BertEncoderConfig = encoders.BertEncoderConfig()
-  # Used for compatibility with continuous finetuning where common BERT config
-  # is used.
-  encoder: encoders.EncoderConfig = encoders.EncoderConfig()
+
+
+class TeamsEncoderConfig(encoders.BertEncoderConfig):
+  pass


 @gin.configurable
-def get_encoder(bert_config, embedding_network=None, hidden_layers=None):
+@base_config.bind(TeamsEncoderConfig)
+def get_encoder(bert_config: TeamsEncoderConfig,
+                embedding_network=None,
+                hidden_layers=None):
   """Gets a 'EncoderScaffold' object.

   Args:
@@ -98,4 +102,4 @@ def get_encoder(bert_config, embedding_network=None, hidden_layers=None):
         dict_outputs=True)
   # Relies on gin configuration to define the Transformer encoder arguments.
-  return networks.encoder_scaffold.EncoderScaffold(**kwargs)
+  return networks.EncoderScaffold(**kwargs)
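With the `@base_config.bind(TeamsEncoderConfig)` decorator, `get_encoder` becomes the BUILDER that `encoders.build_encoder` invokes for `type: any`, replacing the `teams` branch deleted from `encoders.py`. A short sketch of the dispatch path, under the assumption that `bind` records the builder on the config class:

from official.nlp.projects.teams import teams

config = teams.TeamsEncoderConfig(num_layers=1)
# What build_encoder does internally for type="any": call the bound builder.
encoder = config.BUILDER(config)  # Returns an EncoderScaffold.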
@@ -16,12 +16,18 @@
 # pylint: disable=g-doc-return-or-yield,line-too-long
 """TEAMS experiments."""
 import dataclasses

 from official.core import config_definitions as cfg
 from official.core import exp_factory
 from official.modeling import optimization
+from official.nlp.configs import encoders
 from official.nlp.data import pretrain_dataloader
+from official.nlp.data import question_answering_dataloader
+from official.nlp.data import sentence_prediction_dataloader
+from official.nlp.projects.teams import teams
 from official.nlp.projects.teams import teams_task
+from official.nlp.tasks import question_answering
+from official.nlp.tasks import sentence_prediction

 AdamWeightDecay = optimization.AdamWeightDecayConfig
 PolynomialLr = optimization.PolynomialLrConfig
@@ -62,3 +68,42 @@ def teams_pretrain() -> cfg.ExperimentConfig:
       "task.validation_data.is_training != None"
   ])
   return config
+
+
+@exp_factory.register_config_factory("teams/sentence_prediction")
+def teams_sentence_prediction() -> cfg.ExperimentConfig:
+  r"""Teams GLUE."""
+  config = cfg.ExperimentConfig(
+      task=sentence_prediction.SentencePredictionConfig(
+          model=sentence_prediction.ModelConfig(
+              encoder=encoders.EncoderConfig(
+                  type="any", any=teams.TeamsEncoderConfig(num_layers=1))),
+          train_data=sentence_prediction_dataloader
+          .SentencePredictionDataConfig(),
+          validation_data=sentence_prediction_dataloader
+          .SentencePredictionDataConfig(
+              is_training=False, drop_remainder=False)),
+      trainer=cfg.TrainerConfig(optimizer_config=TeamsOptimizationConfig()),
+      restrictions=[
+          "task.train_data.is_training != None",
+          "task.validation_data.is_training != None"
+      ])
+  return config
+
+
+@exp_factory.register_config_factory("teams/squad")
+def teams_squad() -> cfg.ExperimentConfig:
+  """Teams Squad V1/V2."""
+  config = cfg.ExperimentConfig(
+      task=question_answering.QuestionAnsweringConfig(
+          model=question_answering.ModelConfig(
+              encoder=encoders.EncoderConfig(
+                  type="any", any=teams.TeamsEncoderConfig(num_layers=1))),
+          train_data=question_answering_dataloader.QADataConfig(),
+          validation_data=question_answering_dataloader.QADataConfig()),
+      trainer=cfg.TrainerConfig(optimizer_config=TeamsOptimizationConfig()),
+      restrictions=[
+          "task.train_data.is_training != None",
+          "task.validation_data.is_training != None"
+      ])
+  return config
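Once these factories are registered, the experiments can be looked up by name. A minimal usage sketch, assuming this module is `official/nlp/projects/teams/teams_experiments.py` (importing it triggers registration):

from official.core import exp_factory
# Importing the module registers "teams/sentence_prediction" and "teams/squad".
from official.nlp.projects.teams import teams_experiments  # pylint: disable=unused-import

exp_config = exp_factory.get_exp_config("teams/sentence_prediction")
print(exp_config.task.model.encoder.type)  # -> "any"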