Commit 5e5cdec3 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 393909662
parent b8a51be8
@@ -21,74 +21,84 @@ import tensorflow as tf
 from official.modeling import tf_utils
 from official.modeling.hyperparams import base_config
 from official.nlp.configs import encoders
+from official.nlp.modeling import layers
 from official.nlp.modeling import networks


 @dataclasses.dataclass
 class TeamsPretrainerConfig(base_config.Config):
   """Teams pretrainer configuration."""
-  num_masked_tokens: int = 76
-  sequence_length: int = 512
-  num_classes: int = 2
-  discriminator_loss_weight: float = 50.0
+  # Candidate size for multi-word selection task, including the correct word.
+  candidate_size: int = 5
+  # Weight for the generator masked language model task.
+  generator_loss_weight: float = 1.0
+  # Weight for the replaced token detection task.
+  discriminator_rtd_loss_weight: float = 5.0
+  # Weight for the multi-word selection task.
+  discriminator_mws_loss_weight: float = 2.0
   # Whether share embedding network between generator and discriminator.
   tie_embeddings: bool = True
   # Number of bottom layers shared between generator and discriminator.
+  # Non-positive value implies no sharing.
   num_shared_generator_hidden_layers: int = 3
   # Number of bottom layers shared between different discriminator tasks.
   num_discriminator_task_agnostic_layers: int = 11
-  disallow_correct: bool = False
   generator: encoders.BertEncoderConfig = encoders.BertEncoderConfig()
   discriminator: encoders.BertEncoderConfig = encoders.BertEncoderConfig()
+  # Used for compatibility with continuous finetuning where common BERT config
+  # is used.
+  encoder: encoders.EncoderConfig = encoders.EncoderConfig()
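Note: TeamsPretrainerConfig is a base_config.Config dataclass, so the new task weights can be overridden at construction time. Illustrative only; the values below are arbitrary and not part of this change.

# Illustrative instantiation of the config defined above (values arbitrary).
pretrainer_config = TeamsPretrainerConfig(
    candidate_size=5,
    generator_loss_weight=1.0,
    discriminator_rtd_loss_weight=5.0,
    discriminator_mws_loss_weight=2.0,
    num_shared_generator_hidden_layers=3,
    num_discriminator_task_agnostic_layers=11)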
 @gin.configurable
 def get_encoder(bert_config,
-                encoder_scaffold_cls,
-                embedding_inst=None,
-                hidden_inst=None):
+                embedding_network=None,
+                hidden_layers=layers.Transformer):
   """Gets a 'EncoderScaffold' object.

   Args:
     bert_config: A 'modeling.BertConfig'.
-    encoder_scaffold_cls: An EncoderScaffold class.
-    embedding_inst: Embedding instance.
-    hidden_inst: List of hidden layer instances.
+    embedding_network: Embedding network instance.
+    hidden_layers: List of hidden layer instances.

   Returns:
     A encoder object.
   """
-  if embedding_inst is not None:
-    # TODO(hongkuny): evaluate if it is better to put cfg definition in gin.
-    embedding_cfg = dict(
-        vocab_size=bert_config.vocab_size,
-        type_vocab_size=bert_config.type_vocab_size,
-        hidden_size=bert_config.hidden_size,
-        embedding_width=bert_config.embedding_size,
-        max_seq_length=bert_config.max_position_embeddings,
-        initializer=tf.keras.initializers.TruncatedNormal(
-            stddev=bert_config.initializer_range),
-        dropout_rate=bert_config.hidden_dropout_prob,
-    )
-    embedding_inst = networks.PackedSequenceEmbedding(**embedding_cfg)
+  # embedding_size is required for PackedSequenceEmbedding.
+  if bert_config.embedding_size is None:
+    bert_config.embedding_size = bert_config.hidden_size
+  embedding_cfg = dict(
+      vocab_size=bert_config.vocab_size,
+      type_vocab_size=bert_config.type_vocab_size,
+      hidden_size=bert_config.hidden_size,
+      embedding_width=bert_config.embedding_size,
+      max_seq_length=bert_config.max_position_embeddings,
+      initializer=tf.keras.initializers.TruncatedNormal(
+          stddev=bert_config.initializer_range),
+      dropout_rate=bert_config.dropout_rate,
+  )
   hidden_cfg = dict(
       num_attention_heads=bert_config.num_attention_heads,
       intermediate_size=bert_config.intermediate_size,
-      intermediate_activation=tf_utils.get_activation(bert_config.hidden_act),
-      dropout_rate=bert_config.hidden_dropout_prob,
-      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
+      intermediate_activation=tf_utils.get_activation(
+          bert_config.hidden_activation),
+      dropout_rate=bert_config.dropout_rate,
+      attention_dropout_rate=bert_config.attention_dropout_rate,
       kernel_initializer=tf.keras.initializers.TruncatedNormal(
           stddev=bert_config.initializer_range),
   )
+  if embedding_network is None:
+    embedding_network = networks.PackedSequenceEmbedding(**embedding_cfg)
   kwargs = dict(
       embedding_cfg=embedding_cfg,
-      embedding_cls=embedding_inst,
-      hidden_cls=hidden_inst,
+      embedding_cls=embedding_network,
+      hidden_cls=hidden_layers,
       hidden_cfg=hidden_cfg,
-      num_hidden_instances=bert_config.num_hidden_layers,
+      num_hidden_instances=bert_config.num_layers,
       pooled_output_dim=bert_config.hidden_size,
       pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
-          stddev=bert_config.initializer_range))
+          stddev=bert_config.initializer_range),
+      dict_outputs=True)
   # Relies on gin configuration to define the Transformer encoder arguments.
-  return encoder_scaffold_cls(**kwargs)
+  return networks.encoder_scaffold.EncoderScaffold(**kwargs)
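For orientation: the fields the rewritten body reads (num_layers, dropout_rate, hidden_activation, attention_dropout_rate, embedding_size) line up with encoders.BertEncoderConfig, which is also the type of the generator and discriminator fields above. A minimal sketch, assuming such a config is passed in; the sizes are arbitrary.

# Illustrative sketch only; relies on the new defaults for embedding_network
# and hidden_layers introduced in this change.
generator_cfg = encoders.BertEncoderConfig(
    vocab_size=30522, hidden_size=128, num_layers=3,
    num_attention_heads=4, intermediate_size=512)
generator_encoder = get_encoder(generator_cfg)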
@@ -228,16 +228,14 @@ class TeamsPretrainer(tf.keras.Model):
     num_discriminator_task_agnostic_layers: Number of layers shared between
       multi-word selection and random token detection discriminators.
     vocab_size: Size of generator output vocabulary
-    num_classes: Number of classes to predict from the classification network
-      for the generator network (not used now)
+    candidate_size: Candidate size for multi-word selection task,
+      including the correct word.
     mlm_activation: The activation (if any) to use in the masked LM and
       classification networks. If None, no activation will be used.
     mlm_initializer: The initializer (if any) to use in the masked LM and
       classification networks. Defaults to a Glorot uniform initializer.
     output_type: The output style for this network. Can be either `logits` or
       `predictions`.
-    disallow_correct: Whether to disallow the generator to generate the exact
-      same token in the original sentence
   """

   def __init__(self,
@@ -311,12 +309,15 @@ class TeamsPretrainer(tf.keras.Model):
       outputs: A dict of pretrainer model outputs, including
         (1) lm_outputs: A `[batch_size, num_token_predictions, vocab_size]`
         tensor indicating logits on masked positions.
-        (2) sentence_outputs: A `[batch_size, num_classes]` tensor indicating
-        logits for nsp task.
-        (3) disc_logits: A `[batch_size, sequence_length]` tensor indicating
+        (2) disc_rtd_logits: A `[batch_size, sequence_length]` tensor indicating
         logits for discriminator replaced token detection task.
-        (4) disc_label: A `[batch_size, sequence_length]` tensor indicating
+        (3) disc_rtd_label: A `[batch_size, sequence_length]` tensor indicating
         target labels for discriminator replaced token detection task.
+        (4) disc_mws_logits: A `[batch_size, num_token_predictions,
+        candidate_size]` tensor indicating logits for discriminator multi-word
+        selection task.
+        (5) disc_mws_labels: A `[batch_size, num_token_predictions]` tensor
+        indicating target labels for discriminator multi-word selection task.
     """
     input_word_ids = inputs['input_word_ids']
     input_mask = inputs['input_mask']
@@ -326,9 +327,6 @@ class TeamsPretrainer(tf.keras.Model):

     # Runs generator.
     sequence_output = self.generator_network(
         [input_word_ids, input_mask, input_type_ids])['sequence_output']
-    # The generator encoder network may get outputs from all layers.
-    if isinstance(sequence_output, list):
-      sequence_output = sequence_output[-1]
     lm_outputs = self.masked_lm(sequence_output, masked_lm_positions)

@@ -353,11 +351,14 @@ class TeamsPretrainer(tf.keras.Model):
     disc_mws_logits = self.discriminator_mws_head(mws_sequence_outputs[-1],
                                                   masked_lm_positions,
                                                   disc_mws_candidates)
+    disc_mws_label = tf.zeros_like(masked_lm_positions, dtype=tf.int32)

     outputs = {
         'lm_outputs': lm_outputs,
         'disc_rtd_logits': disc_rtd_logits,
         'disc_rtd_label': disc_rtd_label,
         'disc_mws_logits': disc_mws_logits,
         'disc_mws_label': disc_mws_label,
     }

     return outputs
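How the three heads trade off against each other is set by the generator_loss_weight, discriminator_rtd_loss_weight and discriminator_mws_loss_weight fields added above. A rough sketch of combining the outputs dict into one weighted pretraining loss, not this commit's task code; masked_lm_ids and config are placeholder names.

# Hypothetical loss combination; label/weight names are assumptions.
mlm_loss = tf.reduce_mean(
    tf.keras.losses.sparse_categorical_crossentropy(
        masked_lm_ids, outputs['lm_outputs'], from_logits=True))
rtd_loss = tf.reduce_mean(
    tf.keras.losses.binary_crossentropy(
        tf.cast(outputs['disc_rtd_label'], tf.float32),
        outputs['disc_rtd_logits'], from_logits=True))
mws_loss = tf.reduce_mean(
    tf.keras.losses.sparse_categorical_crossentropy(
        outputs['disc_mws_label'], outputs['disc_mws_logits'],
        from_logits=True))
total_loss = (config.generator_loss_weight * mlm_loss +
              config.discriminator_rtd_loss_weight * rtd_loss +
              config.discriminator_mws_loss_weight * mws_loss)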
@@ -408,7 +409,7 @@ class TeamsPretrainer(tf.keras.Model):
   @property
   def checkpoint_items(self):
     """Returns a dictionary of items to be additionally checkpointed."""
-    items = dict(encoder=self.discriminator_network)
+    items = dict(encoder=self.discriminator_mws_network)
     return items

   def get_config(self):
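After this change, checkpoint_items hands out discriminator_mws_network as the encoder to export for downstream fine-tuning. Illustrative usage only; pretrainer and the save path are placeholders, not part of this diff.

# Hypothetical export of the additionally-checkpointed encoder.
checkpoint = tf.train.Checkpoint(**pretrainer.checkpoint_items)
checkpoint.save('/tmp/teams/encoder_ckpt')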
@@ -419,7 +420,7 @@ class TeamsPretrainer(tf.keras.Model):
     return cls(**config)


-def sample_k_from_softmax(logits, k=5, disallow=None, use_topk=False):
+def sample_k_from_softmax(logits, k, disallow=None, use_topk=False):
   """Implement softmax sampling using gumbel softmax trick to select k items.

   Args:
@@ -429,8 +430,8 @@ def sample_k_from_softmax(logits, k, disallow=None, use_topk=False):
     disallow: If `None`, we directly sample tokens from the logits. Otherwise,
       this is a tensor of size [batch_size, num_token_predictions, vocab_size]
       indicating the true word id in each masked position.
-    use_topk: Whether to use tf.nn.top_k or using approximate iterative approach
-      which is faster.
+    use_topk: Whether to use tf.nn.top_k or using iterative approach where the
+      latter is empirically faster.

   Returns:
     sampled_tokens: A [batch_size, num_token_predictions, k] tensor indicating
...
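The docstring above refers to the Gumbel-softmax trick for drawing k items. A minimal, self-contained sketch of that trick, not the function's actual body (which is elided here): perturbing the logits with Gumbel(0, 1) noise and taking the top-k indices samples k distinct tokens from the softmax distribution.

# Minimal Gumbel top-k sampling sketch (illustrative; uses tf.nn.top_k as in
# the use_topk=True path described above).
def _gumbel_topk_sample(logits, k, seed=None):
  uniform = tf.random.uniform(
      tf.shape(logits), minval=1e-9, maxval=1.0, seed=seed)
  gumbel_noise = -tf.math.log(-tf.math.log(uniform))
  _, sampled_tokens = tf.nn.top_k(logits + gumbel_noise, k=k)
  return sampled_tokens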
@@ -103,6 +103,7 @@ class TeamsPretrainerTest(keras_parameterized.TestCase):
     disc_rtd_logits = outputs['disc_rtd_logits']
     disc_rtd_label = outputs['disc_rtd_label']
     disc_mws_logits = outputs['disc_mws_logits']
+    disc_mws_label = outputs['disc_mws_label']

     # Validate that the outputs are of the expected shape.
     expected_lm_shape = [None, num_token_predictions, vocab_size]
@@ -111,6 +112,7 @@ class TeamsPretrainerTest(keras_parameterized.TestCase):
     expected_disc_disc_mws_logits_shape = [
         None, num_token_predictions, candidate_size
     ]
+    expected_disc_disc_mws_label_shape = [None, num_token_predictions]
     self.assertAllEqual(expected_lm_shape, lm_outs.shape.as_list())
     self.assertAllEqual(expected_disc_rtd_logits_shape,
                         disc_rtd_logits.shape.as_list())
@@ -118,6 +120,8 @@ class TeamsPretrainerTest(keras_parameterized.TestCase):
                         disc_rtd_label.shape.as_list())
     self.assertAllEqual(expected_disc_disc_mws_logits_shape,
                         disc_mws_logits.shape.as_list())
+    self.assertAllEqual(expected_disc_disc_mws_label_shape,
+                        disc_mws_label.shape.as_list())

   def test_teams_trainer_tensor_call(self):
     """Validate that the Keras object can be invoked."""
...