Commit 164bab98 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 393909662
parent cf2326ca
@@ -21,74 +21,84 @@ import tensorflow as tf
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import encoders
from official.nlp.modeling import layers
from official.nlp.modeling import networks
@dataclasses.dataclass
class TeamsPretrainerConfig(base_config.Config):
"""Teams pretrainer configuration."""
num_masked_tokens: int = 76
sequence_length: int = 512
num_classes: int = 2
discriminator_loss_weight: float = 50.0
# Candidate size for multi-word selection task, including the correct word.
candidate_size: int = 5
# Weight for the generator masked language model task.
generator_loss_weight: float = 1.0
# Weight for the replaced token detection task.
discriminator_rtd_loss_weight: float = 5.0
# Weight for the multi-word selection task.
discriminator_mws_loss_weight: float = 2.0
# Whether to share the embedding network between generator and discriminator.
tie_embeddings: bool = True
# Number of bottom layers shared between generator and discriminator.
# Non-positive value implies no sharing.
num_shared_generator_hidden_layers: int = 3
# Number of bottom layers shared between different discriminator tasks.
num_discriminator_task_agnostic_layers: int = 11
disallow_correct: bool = False
generator: encoders.BertEncoderConfig = encoders.BertEncoderConfig()
discriminator: encoders.BertEncoderConfig = encoders.BertEncoderConfig()
# Used for compatibility with continuous fine-tuning, where a common BERT
# config is used.
encoder: encoders.EncoderConfig = encoders.EncoderConfig()
@gin.configurable
def get_encoder(bert_config,
encoder_scaffold_cls,
embedding_inst=None,
hidden_inst=None):
embedding_network=None,
hidden_layers=layers.Transformer):
"""Gets a 'EncoderScaffold' object.
Args:
bert_config: A 'modeling.BertConfig'.
encoder_scaffold_cls: An EncoderScaffold class.
embedding_inst: Embedding instance.
hidden_inst: List of hidden layer instances.
embedding_network: Embedding network instance.
hidden_layers: List of hidden layer instances.
Returns:
An encoder object.
"""
if embedding_inst is None:
# TODO(hongkuny): evaluate if it is better to put cfg definition in gin.
embedding_cfg = dict(
vocab_size=bert_config.vocab_size,
type_vocab_size=bert_config.type_vocab_size,
hidden_size=bert_config.hidden_size,
embedding_width=bert_config.embedding_size,
max_seq_length=bert_config.max_position_embeddings,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range),
dropout_rate=bert_config.hidden_dropout_prob,
)
embedding_inst = networks.PackedSequenceEmbedding(**embedding_cfg)
# embedding_size is required for PackedSequenceEmbedding.
if bert_config.embedding_size is None:
bert_config.embedding_size = bert_config.hidden_size
embedding_cfg = dict(
vocab_size=bert_config.vocab_size,
type_vocab_size=bert_config.type_vocab_size,
hidden_size=bert_config.hidden_size,
embedding_width=bert_config.embedding_size,
max_seq_length=bert_config.max_position_embeddings,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range),
dropout_rate=bert_config.dropout_rate,
)
hidden_cfg = dict(
num_attention_heads=bert_config.num_attention_heads,
intermediate_size=bert_config.intermediate_size,
intermediate_activation=tf_utils.get_activation(bert_config.hidden_act),
dropout_rate=bert_config.hidden_dropout_prob,
attention_dropout_rate=bert_config.attention_probs_dropout_prob,
intermediate_activation=tf_utils.get_activation(
bert_config.hidden_activation),
dropout_rate=bert_config.dropout_rate,
attention_dropout_rate=bert_config.attention_dropout_rate,
kernel_initializer=tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range),
)
if embedding_network is None:
embedding_network = networks.PackedSequenceEmbedding(**embedding_cfg)
kwargs = dict(
embedding_cfg=embedding_cfg,
embedding_cls=embedding_inst,
hidden_cls=hidden_inst,
embedding_cls=embedding_network,
hidden_cls=hidden_layers,
hidden_cfg=hidden_cfg,
num_hidden_instances=bert_config.num_hidden_layers,
num_hidden_instances=bert_config.num_layers,
pooled_output_dim=bert_config.hidden_size,
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range))
stddev=bert_config.initializer_range),
dict_outputs=True)
# Relies on gin configuration to define the Transformer encoder arguments.
return encoder_scaffold_cls(**kwargs)
return networks.encoder_scaffold.EncoderScaffold(**kwargs)
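For context, a minimal usage sketch of the updated helper (not part of this commit): it assumes the new signature drops encoder_scaffold_cls, that the module lives at official.nlp.projects.teams.teams, and it uses made-up values for encoders.BertEncoderConfig.

# Hypothetical sketch, not from this commit; module path and config values are assumptions.
from official.nlp.configs import encoders
from official.nlp.projects.teams import teams  # assumed location of get_encoder

bert_config = encoders.BertEncoderConfig(
    vocab_size=30522,
    hidden_size=256,
    # embedding_size is required by PackedSequenceEmbedding; get_encoder falls
    # back to hidden_size when it is left unset.
    embedding_size=128,
    num_layers=4,
    num_attention_heads=4,
    intermediate_size=1024)

# With embedding_network=None, get_encoder builds a PackedSequenceEmbedding from
# bert_config; passing an existing embedding network instead lets the generator
# and discriminator share embeddings (cf. tie_embeddings above).
encoder = teams.get_encoder(bert_config)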
@@ -228,16 +228,14 @@ class TeamsPretrainer(tf.keras.Model):
num_discriminator_task_agnostic_layers: Number of layers shared between
multi-word selection and replaced token detection discriminators.
vocab_size: Size of generator output vocabulary.
num_classes: Number of classes to predict from the classification network
for the generator network (not used now)
candidate_size: Candidate size for multi-word selection task,
including the correct word.
mlm_activation: The activation (if any) to use in the masked LM and
classification networks. If None, no activation will be used.
mlm_initializer: The initializer (if any) to use in the masked LM and
classification networks. Defaults to a Glorot uniform initializer.
output_type: The output style for this network. Can be either `logits` or
`predictions`.
disallow_correct: Whether to disallow the generator from generating the exact
same token as in the original sentence.
"""
def __init__(self,
@@ -311,12 +309,15 @@ class TeamsPretrainer(tf.keras.Model):
outputs: A dict of pretrainer model outputs, including
(1) lm_outputs: A `[batch_size, num_token_predictions, vocab_size]`
tensor indicating logits on masked positions.
(2) sentence_outputs: A `[batch_size, num_classes]` tensor indicating
logits for nsp task.
(3) disc_logits: A `[batch_size, sequence_length]` tensor indicating
(2) disc_rtd_logits: A `[batch_size, sequence_length]` tensor indicating
logits for discriminator replaced token detection task.
(4) disc_label: A `[batch_size, sequence_length]` tensor indicating
(3) disc_rtd_label: A `[batch_size, sequence_length]` tensor indicating
target labels for discriminator replaced token detection task.
(4) disc_mws_logits: A `[batch_size, num_token_predictions,
candidate_size]` tensor indicating logits for discriminator multi-word
selection task.
(5) disc_mws_label: A `[batch_size, num_token_predictions]` tensor
indicating target labels for discriminator multi-word selection task.
"""
input_word_ids = inputs['input_word_ids']
input_mask = inputs['input_mask']
@@ -326,9 +327,6 @@ class TeamsPretrainer(tf.keras.Model):
# Runs generator.
sequence_output = self.generator_network(
[input_word_ids, input_mask, input_type_ids])['sequence_output']
# The generator encoder network may get outputs from all layers.
if isinstance(sequence_output, list):
sequence_output = sequence_output[-1]
lm_outputs = self.masked_lm(sequence_output, masked_lm_positions)
@@ -353,11 +351,14 @@ class TeamsPretrainer(tf.keras.Model):
disc_mws_logits = self.discriminator_mws_head(mws_sequence_outputs[-1],
masked_lm_positions,
disc_mws_candidates)
disc_mws_label = tf.zeros_like(masked_lm_positions, dtype=tf.int32)
outputs = {
'lm_outputs': lm_outputs,
'disc_rtd_logits': disc_rtd_logits,
'disc_rtd_label': disc_rtd_label,
'disc_mws_logits': disc_mws_logits,
'disc_mws_label': disc_mws_label,
}
return outputs
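These output keys pair with the loss weights in TeamsPretrainerConfig. As a rough, hypothetical sketch (this is not the loss code in this repository; the masking and label-weighting details here are simplified assumptions), the three heads could be combined as follows, where masked_lm_ids holds the target ids of the masked positions:

import tensorflow as tf

def combined_teams_loss(outputs, masked_lm_ids, input_mask, pretrainer_cfg):
  # Generator masked language model loss over the masked positions.
  mlm_loss = tf.reduce_mean(
      tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=masked_lm_ids, logits=outputs['lm_outputs']))
  # Replaced token detection loss, averaged over non-padding positions.
  rtd_per_token = tf.nn.sigmoid_cross_entropy_with_logits(
      labels=tf.cast(outputs['disc_rtd_label'], tf.float32),
      logits=outputs['disc_rtd_logits'])
  mask = tf.cast(input_mask, tf.float32)
  rtd_loss = tf.reduce_sum(rtd_per_token * mask) / tf.reduce_sum(mask)
  # Multi-word selection loss; the labels are zeros above, i.e. index 0 is the
  # correct candidate.
  mws_loss = tf.reduce_mean(
      tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=outputs['disc_mws_label'], logits=outputs['disc_mws_logits']))
  return (pretrainer_cfg.generator_loss_weight * mlm_loss +
          pretrainer_cfg.discriminator_rtd_loss_weight * rtd_loss +
          pretrainer_cfg.discriminator_mws_loss_weight * mws_loss)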
@@ -408,7 +409,7 @@ class TeamsPretrainer(tf.keras.Model):
@property
def checkpoint_items(self):
"""Returns a dictionary of items to be additionally checkpointed."""
items = dict(encoder=self.discriminator_network)
items = dict(encoder=self.discriminator_mws_network)
return items
def get_config(self):
@@ -419,7 +420,7 @@ class TeamsPretrainer(tf.keras.Model):
return cls(**config)
def sample_k_from_softmax(logits, k=5, disallow=None, use_topk=False):
def sample_k_from_softmax(logits, k, disallow=None, use_topk=False):
"""Implement softmax sampling using gumbel softmax trick to select k items.
Args:
@@ -429,8 +430,8 @@ def sample_k_from_softmax(logits, k=5, disallow=None, use_topk=False):
disallow: If `None`, we directly sample tokens from the logits. Otherwise,
this is a tensor of size [batch_size, num_token_predictions, vocab_size]
indicating the true word id in each masked position.
use_topk: Whether to use tf.nn.top_k or using approximate iterative approach
which is faster.
use_topk: Whether to use tf.nn.top_k or an iterative approach; the latter is
empirically faster.
Returns:
sampled_tokens: A [batch_size, num_token_predictions, k] tensor indicating
......
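The docstring above refers to the Gumbel-softmax trick; as a standalone illustration (not the repository's sample_k_from_softmax, and ignoring the disallow mask), adding Gumbel(0, 1) noise to the logits and taking the top-k indices draws k distinct tokens per masked position. The use_topk=True path corresponds to the single tf.nn.top_k call below; the iterative alternative mentioned in the docstring is reported to be faster in practice.

import tensorflow as tf

def gumbel_top_k(logits, k):
  # logits: [batch_size, num_token_predictions, vocab_size] unnormalized scores.
  # Returns sampled token ids of shape [batch_size, num_token_predictions, k].
  uniform = tf.random.uniform(tf.shape(logits), minval=1e-9, maxval=1.0)
  gumbel_noise = -tf.math.log(-tf.math.log(uniform))
  _, sampled_tokens = tf.nn.top_k(logits + gumbel_noise, k=k)
  return sampled_tokens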
@@ -103,6 +103,7 @@ class TeamsPretrainerTest(keras_parameterized.TestCase):
disc_rtd_logits = outputs['disc_rtd_logits']
disc_rtd_label = outputs['disc_rtd_label']
disc_mws_logits = outputs['disc_mws_logits']
disc_mws_label = outputs['disc_mws_label']
# Validate that the outputs are of the expected shape.
expected_lm_shape = [None, num_token_predictions, vocab_size]
@@ -111,6 +112,7 @@ class TeamsPretrainerTest(keras_parameterized.TestCase):
expected_disc_disc_mws_logits_shape = [
None, num_token_predictions, candidate_size
]
expected_disc_disc_mws_label_shape = [None, num_token_predictions]
self.assertAllEqual(expected_lm_shape, lm_outs.shape.as_list())
self.assertAllEqual(expected_disc_rtd_logits_shape,
disc_rtd_logits.shape.as_list())
@@ -118,6 +120,8 @@ class TeamsPretrainerTest(keras_parameterized.TestCase):
disc_rtd_label.shape.as_list())
self.assertAllEqual(expected_disc_disc_mws_logits_shape,
disc_mws_logits.shape.as_list())
self.assertAllEqual(expected_disc_disc_mws_label_shape,
disc_mws_label.shape.as_list())
def test_teams_trainer_tensor_call(self):
"""Validate that the Keras object can be invoked."""
......