Commit 8754fa31 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 321238481
parent b860406a
--- a/official/nlp/configs/electra.py
+++ b/official/nlp/configs/electra.py
@@ -34,6 +34,8 @@ class ELECTRAPretrainerConfig(base_config.Config):
   sequence_length: int = 512
   num_classes: int = 2
   discriminator_loss_weight: float = 50.0
+  tie_embeddings: bool = True
+  disallow_correct: bool = False
   generator_encoder: encoders.TransformerEncoderConfig = (
       encoders.TransformerEncoderConfig())
   discriminator_encoder: encoders.TransformerEncoderConfig = (
@@ -60,23 +62,30 @@ def instantiate_pretrainer_from_cfg(
   """Instantiates ElectraPretrainer from the config."""
   generator_encoder_cfg = config.generator_encoder
   discriminator_encoder_cfg = config.discriminator_encoder
-  if generator_network is None:
-    generator_network = encoders.instantiate_encoder_from_cfg(
-        generator_encoder_cfg)
+  # Copy discriminator's embeddings to generator for easier model serialization.
   if discriminator_network is None:
     discriminator_network = encoders.instantiate_encoder_from_cfg(
         discriminator_encoder_cfg)
+  if generator_network is None:
+    if config.tie_embeddings:
+      embedding_layer = discriminator_network.get_embedding_layer()
+      generator_network = encoders.instantiate_encoder_from_cfg(
+          generator_encoder_cfg, embedding_layer=embedding_layer)
+    else:
+      generator_network = encoders.instantiate_encoder_from_cfg(
+          generator_encoder_cfg)
   return electra_pretrainer.ElectraPretrainer(
       generator_network=generator_network,
       discriminator_network=discriminator_network,
       vocab_size=config.generator_encoder.vocab_size,
       num_classes=config.num_classes,
       sequence_length=config.sequence_length,
-      last_hidden_dim=config.generator_encoder.hidden_size,
       num_token_predictions=config.num_masked_tokens,
       mlm_activation=tf_utils.get_activation(
           generator_encoder_cfg.hidden_activation),
       mlm_initializer=tf.keras.initializers.TruncatedNormal(
           stddev=generator_encoder_cfg.initializer_range),
       classification_heads=instantiate_classification_heads_from_cfgs(
-          config.cls_heads))
+          config.cls_heads),
+      disallow_correct=config.disallow_correct)
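With tie_embeddings enabled (the default), the discriminator is now built first and the generator reuses its embedding table, so a single embedding variable serves both networks. A minimal sketch of how this might be exercised, assuming the pretrainer keeps its two encoders as attributes (as the checkpoint_items property added below suggests); field values are illustrative:

config = ELECTRAPretrainerConfig(
    num_masked_tokens=20,
    sequence_length=128,
    tie_embeddings=True,
    generator_encoder=encoders.TransformerEncoderConfig(num_layers=1),
    discriminator_encoder=encoders.TransformerEncoderConfig(num_layers=1))
pretrainer = instantiate_pretrainer_from_cfg(config)
# Both encoders should now hold the very same embedding layer object.
assert (pretrainer.generator_network.get_embedding_layer() is
        pretrainer.discriminator_network.get_embedding_layer())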
--- a/official/nlp/configs/encoders.py
+++ b/official/nlp/configs/encoders.py
@@ -17,12 +17,13 @@

 Includes configurations and instantiation methods.
 """
+from typing import Optional
 import dataclasses
 import gin
 import tensorflow as tf

 from official.modeling import tf_utils
 from official.modeling.hyperparams import base_config
+from official.nlp.modeling import layers
 from official.nlp.modeling import networks
@@ -40,11 +41,13 @@ class TransformerEncoderConfig(base_config.Config):
   max_position_embeddings: int = 512
   type_vocab_size: int = 2
   initializer_range: float = 0.02
+  embedding_size: Optional[int] = None


 @gin.configurable
-def instantiate_encoder_from_cfg(config: TransformerEncoderConfig,
-                                 encoder_cls=networks.TransformerEncoder):
+def instantiate_encoder_from_cfg(
+    config: TransformerEncoderConfig,
+    encoder_cls=networks.TransformerEncoder,
+    embedding_layer: Optional[layers.OnDeviceEmbedding] = None):
   """Instantiate a Transformer encoder network from TransformerEncoderConfig."""
   if encoder_cls.__name__ == "EncoderScaffold":
     embedding_cfg = dict(
@@ -91,5 +94,7 @@ def instantiate_encoder_from_cfg(config: TransformerEncoderConfig,
       max_sequence_length=config.max_position_embeddings,
       type_vocab_size=config.type_vocab_size,
       initializer=tf.keras.initializers.TruncatedNormal(
-          stddev=config.initializer_range))
+          stddev=config.initializer_range),
+      embedding_width=config.embedding_size,
+      embedding_layer=embedding_layer)
   return encoder_network
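The new embedding_layer argument makes the sharing explicit at the encoder level: build one encoder, pull out its embedding layer, and hand it to a second encoder. A hedged sketch (sizes are illustrative; tying only makes sense when both configs use the same embedding width):

disc_cfg = TransformerEncoderConfig(hidden_size=256, embedding_size=128)
gen_cfg = TransformerEncoderConfig(hidden_size=64, embedding_size=128)
discriminator = instantiate_encoder_from_cfg(disc_cfg)
generator = instantiate_encoder_from_cfg(
    gen_cfg, embedding_layer=discriminator.get_embedding_layer())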
--- a/official/nlp/modeling/models/electra_pretrainer.py
+++ b/official/nlp/modeling/models/electra_pretrainer.py
@@ -48,7 +48,6 @@ class ElectraPretrainer(tf.keras.Model):
     num_classes: Number of classes to predict from the classification network
       for the generator network (not used now)
     sequence_length: Input sequence length
-    last_hidden_dim: Last hidden dim of generator transformer output
     num_token_predictions: Number of tokens to predict from the masked LM.
     mlm_activation: The activation (if any) to use in the masked LM and
       classification networks. If None, no activation will be used.
@@ -66,7 +65,6 @@ class ElectraPretrainer(tf.keras.Model):
                vocab_size,
                num_classes,
                sequence_length,
-               last_hidden_dim,
                num_token_predictions,
                mlm_activation=None,
                mlm_initializer='glorot_uniform',
@@ -80,7 +78,6 @@ class ElectraPretrainer(tf.keras.Model):
         'vocab_size': vocab_size,
         'num_classes': num_classes,
         'sequence_length': sequence_length,
-        'last_hidden_dim': last_hidden_dim,
         'num_token_predictions': num_token_predictions,
         'mlm_activation': mlm_activation,
         'mlm_initializer': mlm_initializer,
@@ -95,7 +92,6 @@ class ElectraPretrainer(tf.keras.Model):
     self.vocab_size = vocab_size
     self.num_classes = num_classes
     self.sequence_length = sequence_length
-    self.last_hidden_dim = last_hidden_dim
     self.num_token_predictions = num_token_predictions
     self.mlm_activation = mlm_activation
     self.mlm_initializer = mlm_initializer
@@ -108,10 +104,15 @@ class ElectraPretrainer(tf.keras.Model):
         output=output_type,
         name='generator_masked_lm')
     self.classification = layers.ClassificationHead(
-        inner_dim=last_hidden_dim,
+        inner_dim=generator_network._config_dict['hidden_size'],
         num_classes=num_classes,
         initializer=mlm_initializer,
         name='generator_classification_head')
+    self.discriminator_projection = tf.keras.layers.Dense(
+        units=discriminator_network._config_dict['hidden_size'],
+        activation=mlm_activation,
+        kernel_initializer=mlm_initializer,
+        name='discriminator_projection_head')
     self.discriminator_head = tf.keras.layers.Dense(
         units=1, kernel_initializer=mlm_initializer)
@@ -165,7 +166,8 @@ class ElectraPretrainer(tf.keras.Model):
     if isinstance(disc_sequence_output, list):
       disc_sequence_output = disc_sequence_output[-1]
-    disc_logits = self.discriminator_head(disc_sequence_output)
+    disc_logits = self.discriminator_head(
+        self.discriminator_projection(disc_sequence_output))
     disc_logits = tf.squeeze(disc_logits, axis=-1)

     outputs = {
@@ -214,6 +216,12 @@ class ElectraPretrainer(tf.keras.Model):
         'sampled_tokens': sampled_tokens
     }

+  @property
+  def checkpoint_items(self):
+    """Returns a dictionary of items to be additionally checkpointed."""
+    items = dict(encoder=self.discriminator_network)
+    return items
+
   def get_config(self):
     return self._config
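The checkpoint_items property publishes the discriminator under the standard encoder key, so downstream fine-tuning code that restores a BERT-style encoder can consume an ELECTRA checkpoint unchanged. A minimal sketch of the intended use (the path and classifier_encoder are hypothetical):

pretrainer = ...  # a trained ElectraPretrainer
ckpt = tf.train.Checkpoint(**pretrainer.checkpoint_items)
path = ckpt.save('/tmp/electra/ckpt')  # path illustrative
# Later, restore just the discriminator into a classifier's encoder:
tf.train.Checkpoint(encoder=classifier_encoder).restore(path)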
--- a/official/nlp/modeling/models/electra_pretrainer_test.py
+++ b/official/nlp/modeling/models/electra_pretrainer_test.py
@@ -49,7 +49,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
         vocab_size=vocab_size,
         num_classes=num_classes,
         sequence_length=sequence_length,
-        last_hidden_dim=768,
         num_token_predictions=num_token_predictions,
         disallow_correct=True)
@@ -101,7 +100,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
         vocab_size=100,
         num_classes=2,
         sequence_length=3,
-        last_hidden_dim=768,
         num_token_predictions=2)

     # Create a set of 2-dimensional data tensors to feed into the model.
@@ -140,7 +138,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
         vocab_size=100,
         num_classes=2,
         sequence_length=3,
-        last_hidden_dim=768,
         num_token_predictions=2)

     # Create another BERT trainer via serialization and deserialization.
--- /dev/null
+++ b/official/nlp/tasks/electra_task.py
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ELECTRA pretraining task (Joint Masked LM and Replaced Token Detection)."""
import dataclasses
import tensorflow as tf
from official.core import base_task
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import bert
from official.nlp.configs import electra
from official.nlp.data import pretrain_dataloader


@dataclasses.dataclass
class ELECTRAPretrainConfig(cfg.TaskConfig):
  """The model config."""
  model: electra.ELECTRAPretrainerConfig = electra.ELECTRAPretrainerConfig(
      cls_heads=[
          bert.ClsHeadConfig(
              inner_dim=768,
              num_classes=2,
              dropout_rate=0.1,
              name='next_sentence')
      ])
  train_data: cfg.DataConfig = cfg.DataConfig()
  validation_data: cfg.DataConfig = cfg.DataConfig()


@base_task.register_task_cls(ELECTRAPretrainConfig)
class ELECTRAPretrainTask(base_task.Task):
  """ELECTRA Pretrain Task (Masked LM + Replaced Token Detection)."""

  def build_model(self):
    return electra.instantiate_pretrainer_from_cfg(
        self.task_config.model)
  def build_losses(self,
                   labels,
                   model_outputs,
                   metrics,
                   aux_losses=None) -> tf.Tensor:
    metrics = dict([(metric.name, metric) for metric in metrics])
    # generator lm and (optional) nsp loss.
    lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy(
        labels['masked_lm_ids'],
        tf.cast(model_outputs['lm_outputs'], tf.float32),
        from_logits=True)
    lm_label_weights = labels['masked_lm_weights']
    lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_label_weights)
    lm_denominator_loss = tf.reduce_sum(lm_label_weights)
    mlm_loss = tf.math.divide_no_nan(lm_numerator_loss, lm_denominator_loss)
    metrics['lm_example_loss'].update_state(mlm_loss)
    if 'next_sentence_labels' in labels:
      sentence_labels = labels['next_sentence_labels']
      sentence_outputs = tf.cast(
          model_outputs['sentence_outputs'], dtype=tf.float32)
      sentence_loss = tf.keras.losses.sparse_categorical_crossentropy(
          sentence_labels,
          sentence_outputs,
          from_logits=True)
      metrics['next_sentence_loss'].update_state(sentence_loss)
      total_loss = mlm_loss + sentence_loss
    else:
      total_loss = mlm_loss

    # discriminator replaced token detection (rtd) loss.
    rtd_logits = model_outputs['disc_logits']
    rtd_labels = tf.cast(model_outputs['disc_label'], tf.float32)
    input_mask = tf.cast(labels['input_mask'], tf.float32)
    rtd_ind_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=rtd_logits, labels=rtd_labels)
    rtd_numerator = tf.reduce_sum(input_mask * rtd_ind_loss)
    rtd_denominator = tf.reduce_sum(input_mask)
    rtd_loss = tf.math.divide_no_nan(rtd_numerator, rtd_denominator)
    metrics['discriminator_loss'].update_state(rtd_loss)
    total_loss = total_loss + \
        self.task_config.model.discriminator_loss_weight * rtd_loss

    if aux_losses:
      total_loss += tf.add_n(aux_losses)
    metrics['total_loss'].update_state(total_loss)
    return total_loss
  def build_inputs(self, params, input_context=None):
    """Returns tf.data.Dataset for pretraining."""
    if params.input_path == 'dummy':

      def dummy_data(_):
        dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
        dummy_lm = tf.zeros((1, params.max_predictions_per_seq), dtype=tf.int32)
        return dict(
            input_word_ids=dummy_ids,
            input_mask=dummy_ids,
            input_type_ids=dummy_ids,
            masked_lm_positions=dummy_lm,
            masked_lm_ids=dummy_lm,
            masked_lm_weights=tf.cast(dummy_lm, dtype=tf.float32),
            next_sentence_labels=tf.zeros((1, 1), dtype=tf.int32))

      dataset = tf.data.Dataset.range(1)
      dataset = dataset.repeat()
      dataset = dataset.map(
          dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      return dataset

    return pretrain_dataloader.BertPretrainDataLoader(params).load(
        input_context)
  def build_metrics(self, training=None):
    del training
    metrics = [
        tf.keras.metrics.SparseCategoricalAccuracy(name='masked_lm_accuracy'),
        tf.keras.metrics.Mean(name='lm_example_loss'),
        tf.keras.metrics.SparseCategoricalAccuracy(
            name='discriminator_accuracy'),
    ]
    if self.task_config.train_data.use_next_sentence_label:
      metrics.append(
          tf.keras.metrics.SparseCategoricalAccuracy(
              name='next_sentence_accuracy'))
      metrics.append(tf.keras.metrics.Mean(name='next_sentence_loss'))
    metrics.append(tf.keras.metrics.Mean(name='discriminator_loss'))
    metrics.append(tf.keras.metrics.Mean(name='total_loss'))
    return metrics
  def process_metrics(self, metrics, labels, model_outputs):
    metrics = dict([(metric.name, metric) for metric in metrics])
    if 'masked_lm_accuracy' in metrics:
      metrics['masked_lm_accuracy'].update_state(labels['masked_lm_ids'],
                                                 model_outputs['lm_outputs'],
                                                 labels['masked_lm_weights'])
    if 'next_sentence_accuracy' in metrics:
      metrics['next_sentence_accuracy'].update_state(
          labels['next_sentence_labels'], model_outputs['sentence_outputs'])
    if 'discriminator_accuracy' in metrics:
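      # The RTD head emits a single logit per position, while
      # SparseCategoricalAccuracy expects one score per class. Expanding a
      # logit x into the two pseudo-class logits [-x, x] makes argmax agree
      # with thresholding sigmoid(x) at 0.5, so the binary accuracy is
      # computed correctly.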
      disc_logits_expanded = tf.expand_dims(model_outputs['disc_logits'], -1)
      discrim_full_logits = tf.concat(
          [-1.0 * disc_logits_expanded, disc_logits_expanded], -1)
      metrics['discriminator_accuracy'].update_state(
          model_outputs['disc_label'], discrim_full_logits,
          labels['input_mask'])
  def train_step(self, inputs, model: tf.keras.Model,
                 optimizer: tf.keras.optimizers.Optimizer, metrics):
    """Does forward and backward.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    with tf.GradientTape() as tape:
      outputs = model(inputs, training=True)
      # Computes per-replica loss.
      loss = self.build_losses(
          labels=inputs,
          model_outputs=outputs,
          metrics=metrics,
          aux_losses=model.losses)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      # TODO(b/154564893): enable loss scaling.
      scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    optimizer.apply_gradients(list(zip(grads, tvars)))
    self.process_metrics(metrics, inputs, outputs)
    return {self.loss: loss}
  def validation_step(self, inputs, model: tf.keras.Model, metrics):
    """Validation step.

    Args:
      inputs: a dictionary of input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    outputs = model(inputs, training=False)
    loss = self.build_losses(
        labels=inputs,
        model_outputs=outputs,
        metrics=metrics,
        aux_losses=model.losses)
    self.process_metrics(metrics, inputs, outputs)
    return {self.loss: loss}
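For reference, the objective assembled in build_losses above is

    total_loss = mlm_loss [+ sentence_loss] + discriminator_loss_weight * rtd_loss

where the MLM and RTD terms are masked means (divide_no_nan over the label weights and the input mask, respectively) and discriminator_loss_weight defaults to 50.0 in ELECTRAPretrainerConfig, consistent with the loss weight of 50 used in the ELECTRA paper.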
--- /dev/null
+++ b/official/nlp/tasks/electra_task_test.py
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.tasks.electra_task."""
import tensorflow as tf
from official.nlp.configs import bert
from official.nlp.configs import electra
from official.nlp.configs import encoders
from official.nlp.data import pretrain_dataloader
from official.nlp.tasks import electra_task


class ELECTRAPretrainTaskTest(tf.test.TestCase):

  def test_task(self):
    config = electra_task.ELECTRAPretrainConfig(
        model=electra.ELECTRAPretrainerConfig(
            generator_encoder=encoders.TransformerEncoderConfig(
                vocab_size=30522, num_layers=1),
            discriminator_encoder=encoders.TransformerEncoderConfig(
                vocab_size=30522, num_layers=1),
            num_masked_tokens=20,
            sequence_length=128,
            cls_heads=[
                bert.ClsHeadConfig(
                    inner_dim=10, num_classes=2, name="next_sentence")
            ]),
        train_data=pretrain_dataloader.BertPretrainDataConfig(
            input_path="dummy",
            max_predictions_per_seq=20,
            seq_length=128,
            global_batch_size=1))
    task = electra_task.ELECTRAPretrainTask(config)
    model = task.build_model()
    metrics = task.build_metrics()
    dataset = task.build_inputs(config.train_data)
    iterator = iter(dataset)
    optimizer = tf.keras.optimizers.SGD(lr=0.1)
    task.train_step(next(iterator), model, optimizer, metrics=metrics)
    task.validation_step(next(iterator), model, metrics=metrics)


if __name__ == "__main__":
  tf.test.main()