"new-tutorial/blitz/3_message_passing.py" did not exist on "3ca52b1e5964742652ff2ebe472a2ffafc1812fe"
Commit c68dbef0 authored by Jialu Liu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 400408816
parent 9a1b54cd
 task:
   model:
-    cls_heads: [{activation: tanh, cls_token_idx: 0, dropout_rate: 0.1, inner_dim: 768,
-      name: next_sentence, num_classes: 2}]
-    generator_encoder:
-      bert:
-        attention_dropout_rate: 0.1
-        dropout_rate: 0.1
-        embedding_size: 768
-        hidden_activation: gelu
-        hidden_size: 256
-        initializer_range: 0.02
-        intermediate_size: 1024
-        max_position_embeddings: 512
-        num_attention_heads: 4
-        num_layers: 12
-        type_vocab_size: 2
-        vocab_size: 30522
-    num_masked_tokens: 76
-    sequence_length: 512
-    num_classes: 2
-    discriminator_encoder:
-      bert:
-        attention_dropout_rate: 0.1
-        dropout_rate: 0.1
-        embedding_size: 768
-        hidden_activation: gelu
-        hidden_size: 768
-        initializer_range: 0.02
-        intermediate_size: 3072
-        max_position_embeddings: 512
-        num_attention_heads: 12
-        num_layers: 12
-        type_vocab_size: 2
-        vocab_size: 30522
-    discriminator_loss_weight: 50.0
-    disallow_correct: false
+    candidate_size: 5
+    num_shared_generator_hidden_layers: 3
+    num_discriminator_task_agnostic_layers: 11
     tie_embeddings: true
+    generator:
+      attention_dropout_rate: 0.1
+      dropout_rate: 0.1
+      embedding_size: 768
+      hidden_activation: gelu
+      hidden_size: 768
+      initializer_range: 0.02
+      intermediate_size: 3072
+      max_position_embeddings: 512
+      num_attention_heads: 12
+      num_layers: 6
+      type_vocab_size: 2
+      vocab_size: 30522
+    discriminator:
+      attention_dropout_rate: 0.1
+      dropout_rate: 0.1
+      embedding_size: 768
+      hidden_activation: gelu
+      hidden_size: 768
+      initializer_range: 0.02
+      intermediate_size: 3072
+      max_position_embeddings: 512
+      num_attention_heads: 12
+      num_layers: 6
+      type_vocab_size: 2
+      vocab_size: 30522
   train_data:
     drop_remainder: true
     global_batch_size: 256
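Reviewer note: the model block switches from the ELECTRA-style generator_encoder/discriminator_encoder pair to TEAMS-specific fields (candidate_size for the multi-word-selection task, plus an explicit split between shared generator layers and task-agnostic discriminator layers). For readers mapping the new YAML onto code, here is a minimal, self-contained sketch of a matching config dataclass and loader; the class names and the use of PyYAML are illustrative assumptions, not the Model Garden implementation.

# Illustrative sketch only: dataclasses mirroring the new `task.model` block.
# Field names come from the YAML above; the class names and PyYAML loading
# are assumptions, not the Model Garden implementation.
import dataclasses
import yaml


@dataclasses.dataclass
class EncoderConfig:
  attention_dropout_rate: float = 0.1
  dropout_rate: float = 0.1
  embedding_size: int = 768
  hidden_activation: str = 'gelu'
  hidden_size: int = 768
  initializer_range: float = 0.02
  intermediate_size: int = 3072
  max_position_embeddings: int = 512
  num_attention_heads: int = 12
  num_layers: int = 6
  type_vocab_size: int = 2
  vocab_size: int = 30522


@dataclasses.dataclass
class TeamsModelConfig:
  candidate_size: int = 5
  num_shared_generator_hidden_layers: int = 3
  num_discriminator_task_agnostic_layers: int = 11
  tie_embeddings: bool = True
  generator: EncoderConfig = dataclasses.field(default_factory=EncoderConfig)
  discriminator: EncoderConfig = dataclasses.field(default_factory=EncoderConfig)


def load_model_config(path):
  """Parses the `task.model` block of a YAML file shaped like the one above."""
  with open(path) as f:
    raw = yaml.safe_load(f)['task']['model']
  return TeamsModelConfig(
      generator=EncoderConfig(**raw.pop('generator')),
      discriminator=EncoderConfig(**raw.pop('discriminator')),
      **raw)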
@@ -55,8 +49,8 @@ task:
     use_next_sentence_label: false
     use_position_id: false
 trainer:
-  checkpoint_interval: 6000
-  max_to_keep: 50
+  checkpoint_interval: 4000
+  max_to_keep: 5
   optimizer_config:
     learning_rate:
       polynomial:
@@ -73,8 +67,8 @@ trainer:
         power: 1
         warmup_steps: 10000
       type: polynomial
-  steps_per_loop: 1000
-  summary_interval: 1000
+  steps_per_loop: 4000
+  summary_interval: 4000
   train_steps: 1000000
   validation_interval: 100
   validation_steps: 64
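The trainer hunk keeps the schedule shape (polynomial decay with power: 1, i.e. linear, after 10000 warmup steps, over 1000000 train steps) and only retunes the checkpoint/logging cadence. Below is a minimal sketch of that schedule shape in stock Keras; the peak learning rate is a placeholder, since the actual value falls in the elided part of this diff.

# Sketch of a linear-warmup + polynomial-decay schedule matching the shape of
# the trainer config above (power=1, warmup_steps=10000, train_steps=1000000).
# The peak learning rate below is a placeholder, not the value from the file.
import tensorflow as tf


class WarmupThenPolynomial(tf.keras.optimizers.schedules.LearningRateSchedule):

  def __init__(self, peak_lr=1e-4, warmup_steps=10000, train_steps=1000000,
               power=1.0):
    self._peak_lr = peak_lr
    self._warmup_steps = warmup_steps
    self._decay = tf.keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=peak_lr,
        decay_steps=train_steps - warmup_steps,
        end_learning_rate=0.0,
        power=power)

  def __call__(self, step):
    step = tf.cast(step, tf.float32)
    # Linear ramp up to the peak rate, then polynomial decay to zero.
    warmup_lr = self._peak_lr * step / float(self._warmup_steps)
    decay_lr = self._decay(tf.maximum(step - self._warmup_steps, 0.0))
    return tf.where(step < self._warmup_steps, warmup_lr, decay_lr)


# A schedule object plugs directly into any Keras optimizer.
optimizer = tf.keras.optimizers.Adam(learning_rate=WarmupThenPolynomial())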
@@ -21,6 +21,8 @@ from official.modeling import tf_utils
 from official.nlp.modeling import layers
 from official.nlp.modeling import models
 
+_LOGIT_PENALTY_MULTIPLIER = 10000
+
 
 class ReplacedTokenDetectionHead(tf.keras.layers.Layer):
   """Replaced token detection discriminator head.
@@ -273,10 +275,9 @@ class TeamsPretrainer(tf.keras.Model):
     self.mlm_activation = mlm_activation
     self.mlm_initializer = mlm_initializer
     self.output_type = output_type
-    self.embedding_table = (
-        self.discriminator_mws_network.embedding_network.get_embedding_table())
     self.masked_lm = layers.MaskedLM(
-        embedding_table=self.embedding_table,
+        embedding_table=self.generator_network.embedding_network
+        .get_embedding_table(),
         activation=mlm_activation,
         initializer=mlm_initializer,
         output=output_type,
@@ -290,7 +291,8 @@ class TeamsPretrainer(tf.keras.Model):
         name='discriminator_rtd')
     hidden_cfg = discriminator_cfg['hidden_cfg']
     self.discriminator_mws_head = MultiWordSelectionHead(
-        embedding_table=self.embedding_table,
+        embedding_table=self.discriminator_mws_network.embedding_network
+        .get_embedding_table(),
         activation=hidden_cfg['intermediate_activation'],
         initializer=hidden_cfg['kernel_initializer'],
         output=output_type,
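Reviewer note: taken together, the two hunks above untie the heads' output projections. The generator's MaskedLM now reads the generator encoder's embedding table and the multi-word-selection head reads the discriminator's, instead of both sharing one self.embedding_table. A minimal plain-TensorFlow sketch of the underlying weight-tying pattern follows; the names are illustrative, not the TeamsPretrainer API.

# Sketch of the tied-vs-separate embedding pattern in plain TensorFlow.
# Reusing one variable for both input lookup and output projection ties the
# tables; keeping two variables unties them. Names here are illustrative.
import tensorflow as tf

vocab_size, hidden = 30522, 768
gen_table = tf.Variable(tf.random.normal([vocab_size, hidden]), name='gen_emb')
disc_table = tf.Variable(tf.random.normal([vocab_size, hidden]), name='disc_emb')


def lm_logits(hidden_states, table):
  # Output projection that reuses an embedding table (weight tying).
  return tf.matmul(hidden_states, table, transpose_b=True)


h = tf.random.normal([2, hidden])
gen_logits = lm_logits(h, gen_table)    # generator head: generator's table
mws_logits = lm_logits(h, disc_table)   # MWS head: discriminator's table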
@@ -436,7 +438,7 @@ def sample_k_from_softmax(logits, k, disallow=None, use_topk=False):
   """
   if use_topk:
     if disallow is not None:
-      logits -= 10000.0 * disallow
+      logits -= _LOGIT_PENALTY_MULTIPLIER * disallow
     uniform_noise = tf.random.uniform(
         tf_utils.get_shape_list(logits), minval=0, maxval=1)
     gumbel_noise = -tf.math.log(-tf.math.log(uniform_noise + 1e-9) + 1e-9)
@@ -445,7 +447,7 @@ def sample_k_from_softmax(logits, k, disallow=None, use_topk=False):
   sampled_tokens_list = []
   vocab_size = tf_utils.get_shape_list(logits)[-1]
   if disallow is not None:
-    logits -= 10000.0 * disallow
+    logits -= _LOGIT_PENALTY_MULTIPLIER * disallow
   uniform_noise = tf.random.uniform(
       tf_utils.get_shape_list(logits), minval=0, maxval=1)
@@ -454,7 +456,7 @@ def sample_k_from_softmax(logits, k, disallow=None, use_topk=False):
   for _ in range(k):
     token_ids = tf.argmax(logits, -1, output_type=tf.int32)
     sampled_tokens_list.append(token_ids)
-    logits -= 10000.0 * tf.one_hot(
+    logits -= _LOGIT_PENALTY_MULTIPLIER * tf.one_hot(
         token_ids, depth=vocab_size, dtype=tf.float32)
   sampled_tokens = tf.stack(sampled_tokens_list, -1)
   return sampled_tokens
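For context: sample_k_from_softmax draws k distinct tokens by perturbing the logits with Gumbel noise once and then repeatedly taking the argmax, applying the penalty above to each picked id so it cannot be picked again. A self-contained sketch of the same trick:

# Self-contained sketch of Gumbel-noise sampling without replacement, the
# technique used by sample_k_from_softmax above: repeated argmax over
# Gumbel-perturbed logits, with a large penalty on already-picked ids.
import tensorflow as tf

_PENALTY = 10000.0  # same role as _LOGIT_PENALTY_MULTIPLIER above


def sample_k_without_replacement(logits, k):
  """Draws k distinct ids per row from softmax(logits) via the Gumbel trick."""
  vocab_size = logits.shape[-1]
  uniform = tf.random.uniform(tf.shape(logits), minval=0, maxval=1)
  gumbel = -tf.math.log(-tf.math.log(uniform + 1e-9) + 1e-9)
  noisy = logits + gumbel
  picks = []
  for _ in range(k):
    token_ids = tf.argmax(noisy, -1, output_type=tf.int32)
    picks.append(token_ids)
    # Penalize chosen ids so later rounds cannot repeat them.
    noisy -= _PENALTY * tf.one_hot(token_ids, depth=vocab_size)
  return tf.stack(picks, -1)


# Example: 3 distinct ids from each of two 10-way distributions.
print(sample_k_without_replacement(tf.random.normal([2, 10]), k=3))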