Commit c68dbef0 authored by Jialu Liu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 400408816
parent 9a1b54cd
 task:
   model:
-    cls_heads: [{activation: tanh, cls_token_idx: 0, dropout_rate: 0.1, inner_dim: 768,
-      name: next_sentence, num_classes: 2}]
-    generator_encoder:
-      bert:
+    candidate_size: 5
+    num_shared_generator_hidden_layers: 3
+    num_discriminator_task_agnostic_layers: 11
+    tie_embeddings: true
+    generator:
         attention_dropout_rate: 0.1
         dropout_rate: 0.1
         embedding_size: 768
         hidden_activation: gelu
-        hidden_size: 256
+        hidden_size: 768
         initializer_range: 0.02
-        intermediate_size: 1024
+        intermediate_size: 3072
         max_position_embeddings: 512
-        num_attention_heads: 4
-        num_layers: 12
+        num_attention_heads: 12
+        num_layers: 6
         type_vocab_size: 2
         vocab_size: 30522
-    num_masked_tokens: 76
-    sequence_length: 512
-    num_classes: 2
-    discriminator_encoder:
-      bert:
+    discriminator:
         attention_dropout_rate: 0.1
         dropout_rate: 0.1
         embedding_size: 768
@@ -30,12 +27,9 @@ task:
         intermediate_size: 3072
         max_position_embeddings: 512
         num_attention_heads: 12
-        num_layers: 12
+        num_layers: 6
         type_vocab_size: 2
         vocab_size: 30522
-    discriminator_loss_weight: 50.0
-    disallow_correct: false
-    tie_embeddings: true
   train_data:
     drop_remainder: true
     global_batch_size: 256
@@ -55,8 +49,8 @@ task:
     use_next_sentence_label: false
     use_position_id: false
 trainer:
-  checkpoint_interval: 6000
-  max_to_keep: 50
+  checkpoint_interval: 4000
+  max_to_keep: 5
   optimizer_config:
     learning_rate:
       polynomial:
@@ -73,8 +67,8 @@ trainer:
         power: 1
         warmup_steps: 10000
       type: polynomial
-  steps_per_loop: 1000
-  summary_interval: 1000
+  steps_per_loop: 4000
+  summary_interval: 4000
   train_steps: 1000000
   validation_interval: 100
   validation_steps: 64
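
Taken together, the YAML hunks above convert an ELECTRA-style pretraining config into the TEAMS layout: the generator_encoder/bert and discriminator_encoder/bert nesting collapses into flat generator: and discriminator: blocks, the next-sentence cls_heads and the num_masked_tokens/sequence_length/num_classes fields drop out of the model config, and the TEAMS-specific knobs (candidate_size, num_shared_generator_hidden_layers, num_discriminator_task_agnostic_layers, tie_embeddings) move in at the model level. A minimal sketch of how the new nesting could be inspected, assuming the file is saved as teams_base.yaml (the file name and key layout are inferred from the diff, not confirmed by the commit):

# Hypothetical check of the restructured config above; the file name is
# illustrative and the key set is inferred from the diff.
import yaml  # PyYAML

with open('teams_base.yaml') as f:
  cfg = yaml.safe_load(f)

model = cfg['task']['model']
generator, discriminator = model['generator'], model['discriminator']
# Encoder settings now sit directly under generator/discriminator rather
# than under a nested bert: block, and the sharing knobs are model-level.
assert 'generator_encoder' not in model
print(model['num_shared_generator_hidden_layers'], generator['num_layers'])
print(model['num_discriminator_task_agnostic_layers'], discriminator['num_layers'])
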
@@ -21,6 +21,8 @@ from official.modeling import tf_utils
 from official.nlp.modeling import layers
 from official.nlp.modeling import models
 
+_LOGIT_PENALTY_MULTIPLIER = 10000
+
 
 class ReplacedTokenDetectionHead(tf.keras.layers.Layer):
   """Replaced token detection discriminator head.
@@ -273,10 +275,9 @@ class TeamsPretrainer(tf.keras.Model):
     self.mlm_activation = mlm_activation
     self.mlm_initializer = mlm_initializer
     self.output_type = output_type
-    self.embedding_table = (
-        self.discriminator_mws_network.embedding_network.get_embedding_table())
     self.masked_lm = layers.MaskedLM(
-        embedding_table=self.embedding_table,
+        embedding_table=self.generator_network.embedding_network
+        .get_embedding_table(),
         activation=mlm_activation,
         initializer=mlm_initializer,
         output=output_type,
@@ -290,7 +291,8 @@
         name='discriminator_rtd')
     hidden_cfg = discriminator_cfg['hidden_cfg']
     self.discriminator_mws_head = MultiWordSelectionHead(
-        embedding_table=self.embedding_table,
+        embedding_table=self.discriminator_mws_network.embedding_network
+        .get_embedding_table(),
         activation=hidden_cfg['intermediate_activation'],
         initializer=hidden_cfg['kernel_initializer'],
         output=output_type,
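
These two hunks stop caching a single embedding table on the pretrainer. The removed lines always took the table from the discriminator's embedding network and fed it to both heads; now the masked-LM head reads the generator's table and the multi-word-selection head reads the discriminator's. When tie_embeddings is true the two networks share one table, so this is behavior-preserving; with untied embeddings each head now trains its own network's vocabulary matrix. A standalone sketch of that tying logic using plain Keras embedding layers (toy sizes, not the TeamsPretrainer API):

import tensorflow as tf

def build_embedding_tables(tie_embeddings):
  """Returns (generator, discriminator) embedding layers, optionally shared."""
  generator = tf.keras.layers.Embedding(input_dim=30522, output_dim=768)
  discriminator = (generator if tie_embeddings else
                   tf.keras.layers.Embedding(input_dim=30522, output_dim=768))
  # Run a dummy lookup so both layers create their weight tables.
  _ = generator(tf.constant([[0]]))
  _ = discriminator(tf.constant([[0]]))
  return generator, discriminator

gen, disc = build_embedding_tables(tie_embeddings=True)
print(gen.embeddings is disc.embeddings)   # True: one shared table, as before
gen, disc = build_embedding_tables(tie_embeddings=False)
print(gen.embeddings is disc.embeddings)   # False: separate tables per network
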
@@ -436,7 +438,7 @@ def sample_k_from_softmax(logits, k, disallow=None, use_topk=False):
   """
   if use_topk:
     if disallow is not None:
-      logits -= 10000.0 * disallow
+      logits -= _LOGIT_PENALTY_MULTIPLIER * disallow
     uniform_noise = tf.random.uniform(
         tf_utils.get_shape_list(logits), minval=0, maxval=1)
     gumbel_noise = -tf.math.log(-tf.math.log(uniform_noise + 1e-9) + 1e-9)
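
For context, the branch above is the Gumbel-top-k trick: adding independent Gumbel noise to the logits and keeping the k largest noisy logits draws k distinct tokens from the softmax distribution, presumably via a top_k call just past the end of this hunk. A self-contained illustration with invented toy logits:

import tensorflow as tf

logits = tf.math.log(tf.constant([[0.5, 0.3, 0.15, 0.05]]))  # toy distribution
uniform_noise = tf.random.uniform(tf.shape(logits), minval=0, maxval=1)
# Same transform as in the diff: -log(-log(U)) turns uniform noise into Gumbel.
gumbel_noise = -tf.math.log(-tf.math.log(uniform_noise + 1e-9) + 1e-9)
# Top-k over noisy logits == k draws without replacement from softmax(logits).
_, sampled_tokens = tf.nn.top_k(logits + gumbel_noise, k=2)
print(sampled_tokens.numpy())  # two distinct ids; id 0 wins most often
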
@@ -445,7 +447,7 @@ def sample_k_from_softmax(logits, k, disallow=None, use_topk=False):
     sampled_tokens_list = []
     vocab_size = tf_utils.get_shape_list(logits)[-1]
     if disallow is not None:
-      logits -= 10000.0 * disallow
+      logits -= _LOGIT_PENALTY_MULTIPLIER * disallow
     uniform_noise = tf.random.uniform(
         tf_utils.get_shape_list(logits), minval=0, maxval=1)
@@ -454,7 +456,7 @@ def sample_k_from_softmax(logits, k, disallow=None, use_topk=False):
     for _ in range(k):
       token_ids = tf.argmax(logits, -1, output_type=tf.int32)
       sampled_tokens_list.append(token_ids)
-      logits -= 10000.0 * tf.one_hot(
+      logits -= _LOGIT_PENALTY_MULTIPLIER * tf.one_hot(
           token_ids, depth=vocab_size, dtype=tf.float32)
     sampled_tokens = tf.stack(sampled_tokens_list, -1)
   return sampled_tokens
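
The last three hunks replace the magic number 10000.0 with the named _LOGIT_PENALTY_MULTIPLIER; the mechanism is identical at every call site. Subtracting a huge constant from a token's logit drives its softmax probability to effectively zero, which both masks disallowed tokens and, in the loop above, yields k samples without replacement by knocking out each token once it is picked. A toy demonstration (vocabulary of four, values invented):

import tensorflow as tf

_LOGIT_PENALTY_MULTIPLIER = 10000

logits = tf.constant([[2.0, 1.0, 0.5, 0.1]])
disallow = tf.one_hot([0], depth=4)  # forbid token 0

# Token 0's logit falls to about -9998, so it can no longer win an argmax
# or top_k, and its softmax mass underflows to ~0.
logits -= _LOGIT_PENALTY_MULTIPLIER * disallow
print(tf.nn.softmax(logits)[0, 0].numpy())  # ~0.0

# Greedy k draws without replacement, mirroring the loop in the final hunk.
sampled_tokens_list = []
for _ in range(2):
  token_ids = tf.argmax(logits, -1, output_type=tf.int32)
  sampled_tokens_list.append(token_ids)
  logits -= _LOGIT_PENALTY_MULTIPLIER * tf.one_hot(token_ids, depth=4)
print(tf.stack(sampled_tokens_list, -1).numpy())  # [[1 2]]
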