"...csrc/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "8f2e5c90ce0f55877eddb1f7fee8f8b48004849b"
Commit 8754fa31 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 321238481
parent b860406a
@@ -34,6 +34,8 @@ class ELECTRAPretrainerConfig(base_config.Config):
   sequence_length: int = 512
   num_classes: int = 2
   discriminator_loss_weight: float = 50.0
+  tie_embeddings: bool = True
+  disallow_correct: bool = False
   generator_encoder: encoders.TransformerEncoderConfig = (
       encoders.TransformerEncoderConfig())
   discriminator_encoder: encoders.TransformerEncoderConfig = (
@@ -60,23 +62,30 @@ def instantiate_pretrainer_from_cfg(
   """Instantiates ElectraPretrainer from the config."""
   generator_encoder_cfg = config.generator_encoder
   discriminator_encoder_cfg = config.discriminator_encoder
-  if generator_network is None:
-    generator_network = encoders.instantiate_encoder_from_cfg(
-        generator_encoder_cfg)
+  # Copy discriminator's embeddings to generator for easier model serialization.
   if discriminator_network is None:
     discriminator_network = encoders.instantiate_encoder_from_cfg(
         discriminator_encoder_cfg)
+  if generator_network is None:
+    if config.tie_embeddings:
+      embedding_layer = discriminator_network.get_embedding_layer()
+      generator_network = encoders.instantiate_encoder_from_cfg(
+          generator_encoder_cfg, embedding_layer=embedding_layer)
+    else:
+      generator_network = encoders.instantiate_encoder_from_cfg(
+          generator_encoder_cfg)
   return electra_pretrainer.ElectraPretrainer(
       generator_network=generator_network,
       discriminator_network=discriminator_network,
       vocab_size=config.generator_encoder.vocab_size,
       num_classes=config.num_classes,
       sequence_length=config.sequence_length,
-      last_hidden_dim=config.generator_encoder.hidden_size,
       num_token_predictions=config.num_masked_tokens,
       mlm_activation=tf_utils.get_activation(
           generator_encoder_cfg.hidden_activation),
       mlm_initializer=tf.keras.initializers.TruncatedNormal(
           stddev=generator_encoder_cfg.initializer_range),
       classification_heads=instantiate_classification_heads_from_cfgs(
-          config.cls_heads))
+          config.cls_heads),
+      disallow_correct=config.disallow_correct)
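For context, a minimal sketch of how the new tie_embeddings and disallow_correct config knobs might be exercised when building a pretrainer from a config; the encoder sizes below are illustrative and not part of this change:

# Illustrative usage sketch; hidden sizes are hypothetical, not from this commit.
from official.nlp.configs import electra
from official.nlp.configs import encoders

config = electra.ELECTRAPretrainerConfig(
    num_masked_tokens=20,
    sequence_length=128,
    tie_embeddings=True,     # generator reuses the discriminator's embedding table
    disallow_correct=False,  # forwarded to ElectraPretrainer(disallow_correct=...)
    generator_encoder=encoders.TransformerEncoderConfig(
        vocab_size=30522, num_layers=1, hidden_size=64),
    discriminator_encoder=encoders.TransformerEncoderConfig(
        vocab_size=30522, num_layers=1, hidden_size=256))
pretrainer = electra.instantiate_pretrainer_from_cfg(config)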
@@ -17,12 +17,13 @@
 Includes configurations and instantiation methods.
 """
+from typing import Optional
 import dataclasses
-import gin
 import tensorflow as tf
 from official.modeling import tf_utils
 from official.modeling.hyperparams import base_config
+from official.nlp.modeling import layers
 from official.nlp.modeling import networks
@@ -40,11 +41,13 @@ class TransformerEncoderConfig(base_config.Config):
   max_position_embeddings: int = 512
   type_vocab_size: int = 2
   initializer_range: float = 0.02
+  embedding_size: Optional[int] = None
 
-@gin.configurable
-def instantiate_encoder_from_cfg(config: TransformerEncoderConfig,
-                                 encoder_cls=networks.TransformerEncoder):
+def instantiate_encoder_from_cfg(
+    config: TransformerEncoderConfig,
+    encoder_cls=networks.TransformerEncoder,
+    embedding_layer: Optional[layers.OnDeviceEmbedding] = None):
   """Instantiate a Transformer encoder network from TransformerEncoderConfig."""
   if encoder_cls.__name__ == "EncoderScaffold":
     embedding_cfg = dict(
@@ -91,5 +94,7 @@ def instantiate_encoder_from_cfg(config: TransformerEncoderConfig,
       max_sequence_length=config.max_position_embeddings,
       type_vocab_size=config.type_vocab_size,
       initializer=tf.keras.initializers.TruncatedNormal(
-          stddev=config.initializer_range))
+          stddev=config.initializer_range),
+      embedding_width=config.embedding_size,
+      embedding_layer=embedding_layer)
   return encoder_network
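A quick, hedged sketch of the new embedding_size option (values made up): when set, it is forwarded to the encoder as embedding_width, presumably selecting an embedding table narrower than hidden_size (ALBERT-style factorized embeddings):

# Hedged sketch: configure a narrower embedding table than the hidden size.
# The numbers are illustrative only.
from official.nlp.configs import encoders

cfg = encoders.TransformerEncoderConfig(
    vocab_size=30522, num_layers=1, hidden_size=256, embedding_size=64)
encoder = encoders.instantiate_encoder_from_cfg(cfg)  # embedding_width=64 passed through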
@@ -48,7 +48,6 @@ class ElectraPretrainer(tf.keras.Model):
     num_classes: Number of classes to predict from the classification network
       for the generator network (not used now)
     sequence_length: Input sequence length
-    last_hidden_dim: Last hidden dim of generator transformer output
     num_token_predictions: Number of tokens to predict from the masked LM.
     mlm_activation: The activation (if any) to use in the masked LM and
       classification networks. If None, no activation will be used.
@@ -66,7 +65,6 @@ class ElectraPretrainer(tf.keras.Model):
                vocab_size,
                num_classes,
                sequence_length,
-               last_hidden_dim,
                num_token_predictions,
                mlm_activation=None,
                mlm_initializer='glorot_uniform',
@@ -80,7 +78,6 @@ class ElectraPretrainer(tf.keras.Model):
         'vocab_size': vocab_size,
         'num_classes': num_classes,
         'sequence_length': sequence_length,
-        'last_hidden_dim': last_hidden_dim,
         'num_token_predictions': num_token_predictions,
         'mlm_activation': mlm_activation,
         'mlm_initializer': mlm_initializer,
@@ -95,7 +92,6 @@ class ElectraPretrainer(tf.keras.Model):
     self.vocab_size = vocab_size
     self.num_classes = num_classes
     self.sequence_length = sequence_length
-    self.last_hidden_dim = last_hidden_dim
     self.num_token_predictions = num_token_predictions
     self.mlm_activation = mlm_activation
     self.mlm_initializer = mlm_initializer
@@ -108,10 +104,15 @@ class ElectraPretrainer(tf.keras.Model):
         output=output_type,
         name='generator_masked_lm')
     self.classification = layers.ClassificationHead(
-        inner_dim=last_hidden_dim,
+        inner_dim=generator_network._config_dict['hidden_size'],
         num_classes=num_classes,
         initializer=mlm_initializer,
         name='generator_classification_head')
+    self.discriminator_projection = tf.keras.layers.Dense(
+        units=discriminator_network._config_dict['hidden_size'],
+        activation=mlm_activation,
+        kernel_initializer=mlm_initializer,
+        name='discriminator_projection_head')
     self.discriminator_head = tf.keras.layers.Dense(
         units=1, kernel_initializer=mlm_initializer)
@@ -165,7 +166,8 @@ class ElectraPretrainer(tf.keras.Model):
     if isinstance(disc_sequence_output, list):
       disc_sequence_output = disc_sequence_output[-1]
-    disc_logits = self.discriminator_head(disc_sequence_output)
+    disc_logits = self.discriminator_head(
+        self.discriminator_projection(disc_sequence_output))
     disc_logits = tf.squeeze(disc_logits, axis=-1)

     outputs = {
@@ -214,6 +216,12 @@ class ElectraPretrainer(tf.keras.Model):
         'sampled_tokens': sampled_tokens
     }

+  @property
+  def checkpoint_items(self):
+    """Returns a dictionary of items to be additionally checkpointed."""
+    items = dict(encoder=self.discriminator_network)
+    return items
+
   def get_config(self):
     return self._config
...
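The checkpoint_items property added above follows the Model Garden convention of exposing sub-modules for checkpointing; a hedged sketch of how it would typically be consumed (the directory path is hypothetical, and `pretrainer` is the ElectraPretrainer from the earlier sketch):

# Hedged sketch: save the discriminator encoder exposed via checkpoint_items so a
# downstream task can restore it later.
import tensorflow as tf

checkpoint = tf.train.Checkpoint(**pretrainer.checkpoint_items)  # {'encoder': discriminator}
checkpoint.save('/tmp/electra_pretrain/ckpt')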
@@ -49,7 +49,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
         vocab_size=vocab_size,
         num_classes=num_classes,
         sequence_length=sequence_length,
-        last_hidden_dim=768,
         num_token_predictions=num_token_predictions,
         disallow_correct=True)
@@ -101,7 +100,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
         vocab_size=100,
         num_classes=2,
         sequence_length=3,
-        last_hidden_dim=768,
         num_token_predictions=2)

     # Create a set of 2-dimensional data tensors to feed into the model.
@@ -140,7 +138,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
         vocab_size=100,
         num_classes=2,
         sequence_length=3,
-        last_hidden_dim=768,
         num_token_predictions=2)

     # Create another BERT trainer via serialization and deserialization.
...
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ELECTRA pretraining task (Joint Masked LM and Replaced Token Detection)."""
import dataclasses
import tensorflow as tf
from official.core import base_task
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import bert
from official.nlp.configs import electra
from official.nlp.data import pretrain_dataloader
@dataclasses.dataclass
class ELECTRAPretrainConfig(cfg.TaskConfig):
"""The model config."""
model: electra.ELECTRAPretrainerConfig = electra.ELECTRAPretrainerConfig(
cls_heads=[
bert.ClsHeadConfig(
inner_dim=768,
num_classes=2,
dropout_rate=0.1,
name='next_sentence')
])
train_data: cfg.DataConfig = cfg.DataConfig()
validation_data: cfg.DataConfig = cfg.DataConfig()
@base_task.register_task_cls(ELECTRAPretrainConfig)
class ELECTRAPretrainTask(base_task.Task):
"""ELECTRA Pretrain Task (Masked LM + Replaced Token Detection)."""
def build_model(self):
return electra.instantiate_pretrainer_from_cfg(
self.task_config.model)
def build_losses(self,
labels,
model_outputs,
metrics,
aux_losses=None) -> tf.Tensor:
metrics = dict([(metric.name, metric) for metric in metrics])
# generator lm and (optional) nsp loss.
lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy(
labels['masked_lm_ids'],
tf.cast(model_outputs['lm_outputs'], tf.float32),
from_logits=True)
lm_label_weights = labels['masked_lm_weights']
lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_label_weights)
lm_denominator_loss = tf.reduce_sum(lm_label_weights)
mlm_loss = tf.math.divide_no_nan(lm_numerator_loss, lm_denominator_loss)
metrics['lm_example_loss'].update_state(mlm_loss)
if 'next_sentence_labels' in labels:
sentence_labels = labels['next_sentence_labels']
sentence_outputs = tf.cast(
model_outputs['sentence_outputs'], dtype=tf.float32)
sentence_loss = tf.keras.losses.sparse_categorical_crossentropy(
sentence_labels,
sentence_outputs,
from_logits=True)
metrics['next_sentence_loss'].update_state(sentence_loss)
total_loss = mlm_loss + sentence_loss
else:
total_loss = mlm_loss
# discriminator replaced token detection (rtd) loss.
rtd_logits = model_outputs['disc_logits']
rtd_labels = tf.cast(model_outputs['disc_label'], tf.float32)
input_mask = tf.cast(labels['input_mask'], tf.float32)
rtd_ind_loss = tf.nn.sigmoid_cross_entropy_with_logits(
logits=rtd_logits, labels=rtd_labels)
rtd_numerator = tf.reduce_sum(input_mask * rtd_ind_loss)
rtd_denominator = tf.reduce_sum(input_mask)
rtd_loss = tf.math.divide_no_nan(rtd_numerator, rtd_denominator)
metrics['discriminator_loss'].update_state(rtd_loss)
total_loss = total_loss + \
self.task_config.model.discriminator_loss_weight * rtd_loss
if aux_losses:
total_loss += tf.add_n(aux_losses)
metrics['total_loss'].update_state(total_loss)
return total_loss
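To make the loss composition above concrete, a toy numeric sketch (made-up values) of the masked means and the discriminator_loss_weight scaling:

# Toy numbers only; mirrors the divide_no_nan masked means and the 50x weighting above.
import tensorflow as tf

lm_losses = tf.constant([2.0, 1.0, 3.0])        # per-prediction MLM losses
lm_weights = tf.constant([1.0, 1.0, 0.0])       # padded predictions get weight 0
mlm_loss = tf.math.divide_no_nan(
    tf.reduce_sum(lm_losses * lm_weights), tf.reduce_sum(lm_weights))   # 1.5

rtd_losses = tf.constant([0.2, 0.4, 0.6, 0.8])  # per-token RTD losses
input_mask = tf.constant([1.0, 1.0, 1.0, 0.0])  # only real tokens count
rtd_loss = tf.math.divide_no_nan(
    tf.reduce_sum(input_mask * rtd_losses), tf.reduce_sum(input_mask))  # 0.4

total_loss = mlm_loss + 50.0 * rtd_loss         # default discriminator_loss_weight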
def build_inputs(self, params, input_context=None):
"""Returns tf.data.Dataset for pretraining."""
if params.input_path == 'dummy':
def dummy_data(_):
dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
dummy_lm = tf.zeros((1, params.max_predictions_per_seq), dtype=tf.int32)
return dict(
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids,
masked_lm_positions=dummy_lm,
masked_lm_ids=dummy_lm,
masked_lm_weights=tf.cast(dummy_lm, dtype=tf.float32),
next_sentence_labels=tf.zeros((1, 1), dtype=tf.int32))
dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat()
dataset = dataset.map(
dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset
return pretrain_dataloader.BertPretrainDataLoader(params).load(
input_context)
def build_metrics(self, training=None):
del training
metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(name='masked_lm_accuracy'),
tf.keras.metrics.Mean(name='lm_example_loss'),
tf.keras.metrics.SparseCategoricalAccuracy(
name='discriminator_accuracy'),
]
if self.task_config.train_data.use_next_sentence_label:
metrics.append(
tf.keras.metrics.SparseCategoricalAccuracy(
name='next_sentence_accuracy'))
metrics.append(tf.keras.metrics.Mean(name='next_sentence_loss'))
metrics.append(tf.keras.metrics.Mean(name='discriminator_loss'))
metrics.append(tf.keras.metrics.Mean(name='total_loss'))
return metrics
def process_metrics(self, metrics, labels, model_outputs):
metrics = dict([(metric.name, metric) for metric in metrics])
if 'masked_lm_accuracy' in metrics:
metrics['masked_lm_accuracy'].update_state(labels['masked_lm_ids'],
model_outputs['lm_outputs'],
labels['masked_lm_weights'])
if 'next_sentence_accuracy' in metrics:
metrics['next_sentence_accuracy'].update_state(
labels['next_sentence_labels'], model_outputs['sentence_outputs'])
if 'discriminator_accuracy' in metrics:
disc_logits_expanded = tf.expand_dims(model_outputs['disc_logits'], -1)
discrim_full_logits = tf.concat(
[-1.0 * disc_logits_expanded, disc_logits_expanded], -1)
metrics['discriminator_accuracy'].update_state(
model_outputs['disc_label'], discrim_full_logits,
labels['input_mask'])
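One non-obvious detail above: the discriminator head emits a single logit per token, so the code concatenates [-logits, logits] to form two-class scores that SparseCategoricalAccuracy can compare against the 0/1 RTD labels. A small illustrative check:

# Illustrative check: argmax over [-logit, logit] is 1 exactly when logit > 0,
# i.e. the same decision as thresholding the sigmoid at 0.5.
import tensorflow as tf

disc_logits = tf.constant([[-1.3, 0.7, 2.1]])           # one sequence, 3 tokens
expanded = tf.expand_dims(disc_logits, -1)
two_class = tf.concat([-1.0 * expanded, expanded], -1)  # shape [1, 3, 2]
predictions = tf.argmax(two_class, axis=-1)             # [[0, 1, 1]]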
def train_step(self, inputs, model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer, metrics):
"""Does forward and backward.
Args:
inputs: a dictionary of input tensors.
model: the model, forward pass definition.
optimizer: the optimizer for this training step.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
with tf.GradientTape() as tape:
outputs = model(inputs, training=True)
# Computes per-replica loss.
loss = self.build_losses(
labels=inputs,
model_outputs=outputs,
metrics=metrics,
aux_losses=model.losses)
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
# TODO(b/154564893): enable loss scaling.
scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
tvars = model.trainable_variables
grads = tape.gradient(scaled_loss, tvars)
optimizer.apply_gradients(list(zip(grads, tvars)))
self.process_metrics(metrics, inputs, outputs)
return {self.loss: loss}
def validation_step(self, inputs, model: tf.keras.Model, metrics):
"""Validatation step.
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
outputs = model(inputs, training=False)
loss = self.build_losses(
labels=inputs,
model_outputs=outputs,
metrics=metrics,
aux_losses=model.losses)
self.process_metrics(metrics, inputs, outputs)
return {self.loss: loss}
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.tasks.electra_task."""
import tensorflow as tf
from official.nlp.configs import bert
from official.nlp.configs import electra
from official.nlp.configs import encoders
from official.nlp.data import pretrain_dataloader
from official.nlp.tasks import electra_task
class ELECTRAPretrainTaskTest(tf.test.TestCase):
def test_task(self):
config = electra_task.ELECTRAPretrainConfig(
model=electra.ELECTRAPretrainerConfig(
generator_encoder=encoders.TransformerEncoderConfig(
vocab_size=30522, num_layers=1),
discriminator_encoder=encoders.TransformerEncoderConfig(
vocab_size=30522, num_layers=1),
num_masked_tokens=20,
sequence_length=128,
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
]),
train_data=pretrain_dataloader.BertPretrainDataConfig(
input_path="dummy",
max_predictions_per_seq=20,
seq_length=128,
global_batch_size=1))
task = electra_task.ELECTRAPretrainTask(config)
model = task.build_model()
metrics = task.build_metrics()
dataset = task.build_inputs(config.train_data)
iterator = iter(dataset)
optimizer = tf.keras.optimizers.SGD(lr=0.1)
task.train_step(next(iterator), model, optimizer, metrics=metrics)
task.validation_step(next(iterator), model, metrics=metrics)
if __name__ == "__main__":
tf.test.main()