Unverified Commit ca552843 authored by Srihari Humbarwadi, committed by GitHub

Merge branch 'panoptic-segmentation' into panoptic-segmentation

parents 7e2f7a35 6b90e134
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TEAMS model configurations and instantiation methods."""
import dataclasses
import gin
import tensorflow as tf
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import encoders
from official.nlp.modeling import layers
from official.nlp.modeling import networks
@dataclasses.dataclass
class TeamsPretrainerConfig(base_config.Config):
"""Teams pretrainer configuration."""
# Candidate size for multi-word selection task, including the correct word.
candidate_size: int = 5
# Weight for the generator masked language model task.
generator_loss_weight: float = 1.0
# Weight for the replaced token detection task.
discriminator_rtd_loss_weight: float = 5.0
# Weight for the multi-word selection task.
discriminator_mws_loss_weight: float = 2.0
# Whether to share the embedding network between the generator and discriminator.
tie_embeddings: bool = True
# Number of bottom layers shared between generator and discriminator.
# Non-positive value implies no sharing.
num_shared_generator_hidden_layers: int = 3
# Number of bottom layers shared between different discriminator tasks.
num_discriminator_task_agnostic_layers: int = 11
generator: encoders.BertEncoderConfig = encoders.BertEncoderConfig()
discriminator: encoders.BertEncoderConfig = encoders.BertEncoderConfig()
# Used for compatibility with continuous fine-tuning, where a common BERT
# config is used.
encoder: encoders.EncoderConfig = encoders.EncoderConfig()
@gin.configurable
def get_encoder(bert_config,
embedding_network=None,
hidden_layers=layers.Transformer):
"""Gets a 'EncoderScaffold' object.
Args:
bert_config: A 'modeling.BertConfig'.
embedding_network: Embedding network instance.
hidden_layers: List of hidden layer instances.
Returns:
An encoder object.
"""
# embedding_size is required for PackedSequenceEmbedding.
if bert_config.embedding_size is None:
bert_config.embedding_size = bert_config.hidden_size
embedding_cfg = dict(
vocab_size=bert_config.vocab_size,
type_vocab_size=bert_config.type_vocab_size,
hidden_size=bert_config.hidden_size,
embedding_width=bert_config.embedding_size,
max_seq_length=bert_config.max_position_embeddings,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range),
dropout_rate=bert_config.dropout_rate,
)
hidden_cfg = dict(
num_attention_heads=bert_config.num_attention_heads,
intermediate_size=bert_config.intermediate_size,
intermediate_activation=tf_utils.get_activation(
bert_config.hidden_activation),
dropout_rate=bert_config.dropout_rate,
attention_dropout_rate=bert_config.attention_dropout_rate,
kernel_initializer=tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range),
)
if embedding_network is None:
embedding_network = networks.PackedSequenceEmbedding(**embedding_cfg)
kwargs = dict(
embedding_cfg=embedding_cfg,
embedding_cls=embedding_network,
hidden_cls=hidden_layers,
hidden_cfg=hidden_cfg,
num_hidden_instances=bert_config.num_layers,
pooled_output_dim=bert_config.hidden_size,
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range),
dict_outputs=True)
# Relies on gin configuration to define the Transformer encoder arguments.
return networks.encoder_scaffold.EncoderScaffold(**kwargs)
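# Illustrative sketch (not part of this module): building a small TEAMS config
# and its discriminator encoder with `get_encoder`. The encoder settings below
# are assumptions chosen only for this example.
if __name__ == '__main__':
  example_config = TeamsPretrainerConfig(
      generator=encoders.BertEncoderConfig(vocab_size=30522, num_layers=2),
      discriminator=encoders.BertEncoderConfig(vocab_size=30522, num_layers=2),
      num_shared_generator_hidden_layers=1)
  example_discriminator = get_encoder(example_config.discriminator)
  print(type(example_discriminator).__name__)  # EncoderScaffold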
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# pylint: disable=g-doc-return-or-yield,line-too-long
"""TEAMS experiments."""
import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import optimization
from official.nlp.data import pretrain_dataloader
from official.nlp.projects.teams import teams_task
AdamWeightDecay = optimization.AdamWeightDecayConfig
PolynomialLr = optimization.PolynomialLrConfig
PolynomialWarmupConfig = optimization.PolynomialWarmupConfig
@dataclasses.dataclass
class TeamsOptimizationConfig(optimization.OptimizationConfig):
"""TEAMS optimization config."""
optimizer: optimization.OptimizerConfig = optimization.OptimizerConfig(
type="adamw",
adamw=AdamWeightDecay(
weight_decay_rate=0.01,
exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
epsilon=1e-6))
learning_rate: optimization.LrConfig = optimization.LrConfig(
type="polynomial",
polynomial=PolynomialLr(
initial_learning_rate=1e-4,
decay_steps=1000000,
end_learning_rate=0.0))
warmup: optimization.WarmupConfig = optimization.WarmupConfig(
type="polynomial", polynomial=PolynomialWarmupConfig(warmup_steps=10000))
@exp_factory.register_config_factory("teams/pretraining")
def teams_pretrain() -> cfg.ExperimentConfig:
"""TEAMS pretraining."""
config = cfg.ExperimentConfig(
task=teams_task.TeamsPretrainTaskConfig(
train_data=pretrain_dataloader.BertPretrainDataConfig(),
validation_data=pretrain_dataloader.BertPretrainDataConfig(
is_training=False)),
trainer=cfg.TrainerConfig(
optimizer_config=TeamsOptimizationConfig(), train_steps=1000000),
restrictions=[
"task.train_data.is_training != None",
"task.validation_data.is_training != None"
])
return config
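# Illustrative sketch (not part of this module): the registered experiment can
# be fetched through the factory and individual fields overridden; the step
# count below is an arbitrary example value.
if __name__ == "__main__":
  example_config = exp_factory.get_exp_config("teams/pretraining")
  example_config.trainer.train_steps = 10000
  print(example_config.trainer.train_steps)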
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for teams_experiments."""
from absl.testing import parameterized
import tensorflow as tf
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.core import config_definitions as cfg
from official.core import exp_factory
class TeamsExperimentsTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(('teams/pretraining',))
def test_teams_experiments(self, config_name):
config = exp_factory.get_exp_config(config_name)
self.assertIsInstance(config, cfg.ExperimentConfig)
self.assertIsInstance(config.task.train_data, cfg.DataConfig)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trainer network for TEAMS models."""
# pylint: disable=g-classes-have-attributes
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling import layers
from official.nlp.modeling import models
class ReplacedTokenDetectionHead(tf.keras.layers.Layer):
"""Replaced token detection discriminator head.
Arguments:
encoder_cfg: Encoder config, used to create hidden layers and head.
num_task_agnostic_layers: Number of task agnostic layers in the
discriminator.
output: The output style for this network. Can be either 'logits' or
'predictions'.
"""
def __init__(self,
encoder_cfg,
num_task_agnostic_layers,
output='logits',
name='rtd',
**kwargs):
super(ReplacedTokenDetectionHead, self).__init__(name=name, **kwargs)
self.num_task_agnostic_layers = num_task_agnostic_layers
self.hidden_size = encoder_cfg['embedding_cfg']['hidden_size']
self.num_hidden_instances = encoder_cfg['num_hidden_instances']
self.hidden_cfg = encoder_cfg['hidden_cfg']
self.activation = self.hidden_cfg['intermediate_activation']
self.initializer = self.hidden_cfg['kernel_initializer']
self.hidden_layers = []
for i in range(self.num_task_agnostic_layers, self.num_hidden_instances):
self.hidden_layers.append(
layers.Transformer(
num_attention_heads=self.hidden_cfg['num_attention_heads'],
intermediate_size=self.hidden_cfg['intermediate_size'],
intermediate_activation=self.activation,
dropout_rate=self.hidden_cfg['dropout_rate'],
attention_dropout_rate=self.hidden_cfg['attention_dropout_rate'],
kernel_initializer=self.initializer,
name='transformer/layer_%d_rtd' % i))
self.dense = tf.keras.layers.Dense(
self.hidden_size,
activation=self.activation,
kernel_initializer=self.initializer,
name='transform/rtd_dense')
self.rtd_head = tf.keras.layers.Dense(
units=1, kernel_initializer=self.initializer,
name='transform/rtd_head')
if output not in ('predictions', 'logits'):
raise ValueError(
('Unknown `output` value "%s". `output` can be either "logits" or '
'"predictions"') % output)
self._output_type = output
def call(self, sequence_data, input_mask):
"""Compute inner-products of hidden vectors with sampled element embeddings.
Args:
sequence_data: A [batch_size, seq_length, num_hidden] tensor.
input_mask: A [batch_size, seq_length] binary mask to separate the input
from the padding.
Returns:
A [batch_size, seq_length] tensor.
"""
attention_mask = layers.SelfAttentionMask()([sequence_data, input_mask])
data = sequence_data
for hidden_layer in self.hidden_layers:
data = hidden_layer([data, attention_mask])
rtd_logits = self.rtd_head(self.dense(data))
return tf.squeeze(rtd_logits, axis=-1)
class MultiWordSelectionHead(tf.keras.layers.Layer):
"""Multi-word selection discriminator head.
Arguments:
embedding_table: The embedding table.
activation: The activation, if any, for the dense layer.
initializer: The initializer for the dense layer. Defaults to a Glorot
uniform initializer.
output: The output style for this network. Can be either 'logits' or
'predictions'.
"""
def __init__(self,
embedding_table,
activation=None,
initializer='glorot_uniform',
output='logits',
name='mws',
**kwargs):
super(MultiWordSelectionHead, self).__init__(name=name, **kwargs)
self.embedding_table = embedding_table
self.activation = activation
self.initializer = tf.keras.initializers.get(initializer)
self._vocab_size, self.embed_size = self.embedding_table.shape
self.dense = tf.keras.layers.Dense(
self.embed_size,
activation=self.activation,
kernel_initializer=self.initializer,
name='transform/mws_dense')
self.layer_norm = tf.keras.layers.LayerNormalization(
axis=-1, epsilon=1e-12, name='transform/mws_layernorm')
if output not in ('predictions', 'logits'):
raise ValueError(
('Unknown `output` value "%s". `output` can be either "logits" or '
'"predictions"') % output)
self._output_type = output
def call(self, sequence_data, masked_positions, candidate_sets):
"""Compute inner-products of hidden vectors with sampled element embeddings.
Args:
sequence_data: A [batch_size, seq_length, num_hidden] tensor.
masked_positions: A [batch_size, num_prediction] tensor.
candidate_sets: A [batch_size, num_prediction, k] tensor.
Returns:
A [batch_size, num_prediction, k] tensor.
"""
# Gets shapes for later usage
candidate_set_shape = tf_utils.get_shape_list(candidate_sets)
num_prediction = candidate_set_shape[1]
# Gathers hidden vectors -> (batch_size, num_prediction, 1, embed_size)
masked_lm_input = self._gather_indexes(sequence_data, masked_positions)
lm_data = self.dense(masked_lm_input)
lm_data = self.layer_norm(lm_data)
lm_data = tf.expand_dims(
tf.reshape(lm_data, [-1, num_prediction, self.embed_size]), 2)
# Gathers embeddings -> (batch_size, num_prediction, embed_size, k)
flat_candidate_sets = tf.reshape(candidate_sets, [-1])
candidate_embeddings = tf.gather(self.embedding_table, flat_candidate_sets)
candidate_embeddings = tf.reshape(
candidate_embeddings,
tf.concat([tf.shape(candidate_sets), [self.embed_size]], axis=0)
)
candidate_embeddings.set_shape(
candidate_sets.shape.as_list() + [self.embed_size])
candidate_embeddings = tf.transpose(candidate_embeddings, [0, 1, 3, 2])
# matrix multiplication + squeeze -> (batch_size, num_prediction, k)
logits = tf.matmul(lm_data, candidate_embeddings)
logits = tf.squeeze(logits, 2)
if self._output_type == 'logits':
return logits
return tf.nn.log_softmax(logits)
def _gather_indexes(self, sequence_tensor, positions):
"""Gathers the vectors at the specific positions.
Args:
sequence_tensor: Sequence output of shape
(`batch_size`, `seq_length`, `num_hidden`) where `num_hidden` is
number of hidden units.
positions: Position ids of tokens in batched sequences.
Returns:
Sequence tensor of shape (batch_size * num_predictions,
num_hidden).
"""
sequence_shape = tf_utils.get_shape_list(
sequence_tensor, name='sequence_output_tensor')
batch_size, seq_length, width = sequence_shape
flat_offsets = tf.reshape(
tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
flat_positions = tf.reshape(positions + flat_offsets, [-1])
flat_sequence_tensor = tf.reshape(sequence_tensor,
[batch_size * seq_length, width])
output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
return output_tensor
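# Worked example (illustrative, with assumed toy shapes) of the offset trick in
# `_gather_indexes` above: with batch_size=2 and seq_length=4, the per-example
# offsets are [[0], [4]], so positions [[0, 2], [1, 3]] become flat indices
# [0, 2, 5, 7] into the [batch_size * seq_length, width] reshaped tensor, and a
# single tf.gather returns the hidden vectors at those masked positions.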
@tf.keras.utils.register_keras_serializable(package='Text')
class TeamsPretrainer(tf.keras.Model):
"""TEAMS network training model.
This is an implementation of the network structure described in "Training
ELECTRA Augmented with Multi-word Selection"
(https://arxiv.org/abs/2106.00139).
The TeamsPretrainer allows a user to pass in two transformer encoders, one for
the generator and the other for the discriminator (multi-word selection). The
pretrainer then instantiates the masked language model (on the generator side)
and the classification networks (both the multi-word selection head and the
replaced token detection head) that are used to create the training objectives.
*Note* that the model is constructed with the Keras subclassing API, where
layers are defined inside `__init__` and `call()` implements the computation.
Args:
generator_network: A transformer encoder for the generator; this network
should output a sequence output.
discriminator_mws_network: A transformer encoder for the multi-word selection
discriminator; this network should output a sequence output.
num_discriminator_task_agnostic_layers: Number of layers shared between the
multi-word selection and replaced token detection discriminators.
vocab_size: Size of the generator output vocabulary.
candidate_size: Candidate size for multi-word selection task,
including the correct word.
mlm_activation: The activation (if any) to use in the masked LM and
classification networks. If None, no activation will be used.
mlm_initializer: The initializer (if any) to use in the masked LM and
classification networks. Defaults to a Glorot uniform initializer.
output_type: The output style for this network. Can be either `logits` or
`predictions`.
"""
def __init__(self,
generator_network,
discriminator_mws_network,
num_discriminator_task_agnostic_layers,
vocab_size,
candidate_size=5,
mlm_activation=None,
mlm_initializer='glorot_uniform',
output_type='logits',
**kwargs):
super().__init__()
self._config = {
'generator_network':
generator_network,
'discriminator_mws_network':
discriminator_mws_network,
'num_discriminator_task_agnostic_layers':
num_discriminator_task_agnostic_layers,
'vocab_size':
vocab_size,
'candidate_size':
candidate_size,
'mlm_activation':
mlm_activation,
'mlm_initializer':
mlm_initializer,
'output_type':
output_type,
}
for k, v in kwargs.items():
self._config[k] = v
self.generator_network = generator_network
self.discriminator_mws_network = discriminator_mws_network
self.vocab_size = vocab_size
self.candidate_size = candidate_size
self.mlm_activation = mlm_activation
self.mlm_initializer = mlm_initializer
self.output_type = output_type
self.embedding_table = (
self.discriminator_mws_network.embedding_network.get_embedding_table())
self.masked_lm = layers.MaskedLM(
embedding_table=self.embedding_table,
activation=mlm_activation,
initializer=mlm_initializer,
output=output_type,
name='generator_masked_lm')
discriminator_cfg = self.discriminator_mws_network.get_config()
self.num_task_agnostic_layers = num_discriminator_task_agnostic_layers
self.discriminator_rtd_head = ReplacedTokenDetectionHead(
encoder_cfg=discriminator_cfg,
num_task_agnostic_layers=self.num_task_agnostic_layers,
output=output_type,
name='discriminator_rtd')
hidden_cfg = discriminator_cfg['hidden_cfg']
self.discriminator_mws_head = MultiWordSelectionHead(
embedding_table=self.embedding_table,
activation=hidden_cfg['intermediate_activation'],
initializer=hidden_cfg['kernel_initializer'],
output=output_type,
name='discriminator_mws')
def call(self, inputs):
"""TEAMS forward pass.
Args:
inputs: A dict of all inputs, same as the standard BERT model.
Returns:
outputs: A dict of pretrainer model outputs, including
(1) lm_outputs: A `[batch_size, num_token_predictions, vocab_size]`
tensor indicating logits on masked positions.
(2) disc_rtd_logits: A `[batch_size, sequence_length]` tensor indicating
logits for discriminator replaced token detection task.
(3) disc_rtd_label: A `[batch_size, sequence_length]` tensor indicating
target labels for discriminator replaced token detection task.
(4) disc_mws_logits: A `[batch_size, num_token_predictions,
candidate_size]` tensor indicating logits for discriminator multi-word
selection task.
(5) disc_mws_label: A `[batch_size, num_token_predictions]` tensor
indicating target labels for discriminator multi-word selection task.
"""
input_word_ids = inputs['input_word_ids']
input_mask = inputs['input_mask']
input_type_ids = inputs['input_type_ids']
masked_lm_positions = inputs['masked_lm_positions']
# Runs generator.
sequence_output = self.generator_network(
[input_word_ids, input_mask, input_type_ids])['sequence_output']
lm_outputs = self.masked_lm(sequence_output, masked_lm_positions)
# Samples tokens from generator.
fake_data = self._get_fake_data(inputs, lm_outputs)
# Runs discriminator.
disc_input = fake_data['inputs']
disc_rtd_label = fake_data['is_fake_tokens']
disc_mws_candidates = fake_data['candidate_set']
mws_sequence_outputs = self.discriminator_mws_network([
disc_input['input_word_ids'], disc_input['input_mask'],
disc_input['input_type_ids']
])['encoder_outputs']
# Applies replaced token detection with input selected based on
# self.num_discriminator_task_agnostic_layers
disc_rtd_logits = self.discriminator_rtd_head(
mws_sequence_outputs[self.num_task_agnostic_layers - 1], input_mask)
# Applies multi-word selection.
disc_mws_logits = self.discriminator_mws_head(mws_sequence_outputs[-1],
masked_lm_positions,
disc_mws_candidates)
disc_mws_label = tf.zeros_like(masked_lm_positions, dtype=tf.int32)
outputs = {
'lm_outputs': lm_outputs,
'disc_rtd_logits': disc_rtd_logits,
'disc_rtd_label': disc_rtd_label,
'disc_mws_logits': disc_mws_logits,
'disc_mws_label': disc_mws_label,
}
return outputs
def _get_fake_data(self, inputs, mlm_logits):
"""Generate corrupted data for discriminator.
Note it is poosible for sampled token to be the same as the correct one.
Args:
inputs: A dict of all inputs, same as the input of `call()` function
mlm_logits: The generator's output logits
Returns:
A dict of generated fake data
"""
inputs = models.electra_pretrainer.unmask(inputs, duplicate=True)
# Samples replaced token.
sampled_tokens = tf.stop_gradient(
models.electra_pretrainer.sample_from_softmax(
mlm_logits, disallow=None))
sampled_tokids = tf.argmax(sampled_tokens, axis=-1, output_type=tf.int32)
# Prepares input and label for replaced token detection task.
updated_input_ids, masked = models.electra_pretrainer.scatter_update(
inputs['input_word_ids'], sampled_tokids, inputs['masked_lm_positions'])
rtd_labels = masked * (1 - tf.cast(
tf.equal(updated_input_ids, inputs['input_word_ids']), tf.int32))
updated_inputs = models.electra_pretrainer.get_updated_inputs(
inputs, duplicate=True, input_word_ids=updated_input_ids)
# Samples (candidate_size - 1) negatives and concatenates them with the true tokens.
disallow = tf.one_hot(
inputs['masked_lm_ids'], depth=self.vocab_size, dtype=tf.float32)
sampled_candidates = tf.stop_gradient(
sample_k_from_softmax(mlm_logits, k=self.candidate_size-1,
disallow=disallow))
true_token_id = tf.expand_dims(inputs['masked_lm_ids'], -1)
candidate_set = tf.concat([true_token_id, sampled_candidates], -1)
return {
'inputs': updated_inputs,
'is_fake_tokens': rtd_labels,
'sampled_tokens': sampled_tokens,
'candidate_set': candidate_set
}
@property
def checkpoint_items(self):
"""Returns a dictionary of items to be additionally checkpointed."""
items = dict(encoder=self.discriminator_mws_network)
return items
def get_config(self):
return self._config
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
def sample_k_from_softmax(logits, k, disallow=None, use_topk=False):
"""Implement softmax sampling using gumbel softmax trick to select k items.
Args:
logits: A [batch_size, num_token_predictions, vocab_size] tensor indicating
the generator output logits for each masked position.
k: Number of samples.
disallow: If `None`, we directly sample tokens from the logits. Otherwise,
this is a tensor of size [batch_size, num_token_predictions, vocab_size]
indicating the true word id in each masked position.
use_topk: Whether to use tf.nn.top_k or an iterative approach; the latter is
empirically faster.
Returns:
sampled_tokens: A [batch_size, num_token_predictions, k] tensor indicating
the sampled word id in each masked position.
"""
if use_topk:
if disallow is not None:
logits -= 10000.0 * disallow
uniform_noise = tf.random.uniform(
tf_utils.get_shape_list(logits), minval=0, maxval=1)
gumbel_noise = -tf.math.log(-tf.math.log(uniform_noise + 1e-9) + 1e-9)
_, sampled_tokens = tf.nn.top_k(logits + gumbel_noise, k=k, sorted=False)
else:
sampled_tokens_list = []
vocab_size = tf_utils.get_shape_list(logits)[-1]
if disallow is not None:
logits -= 10000.0 * disallow
uniform_noise = tf.random.uniform(
tf_utils.get_shape_list(logits), minval=0, maxval=1)
gumbel_noise = -tf.math.log(-tf.math.log(uniform_noise + 1e-9) + 1e-9)
logits += gumbel_noise
for _ in range(k):
token_ids = tf.argmax(logits, -1, output_type=tf.int32)
sampled_tokens_list.append(token_ids)
logits -= 10000.0 * tf.one_hot(
token_ids, depth=vocab_size, dtype=tf.float32)
sampled_tokens = tf.stack(sampled_tokens_list, -1)
return sampled_tokens
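# Illustrative sketch (not part of this module): drawing candidate sets from
# random generator logits; the shapes below are assumptions for this example.
if __name__ == '__main__':
  example_logits = tf.random.normal([2, 3, 100])  # [batch, num_predictions, vocab]
  example_candidates = sample_k_from_softmax(example_logits, k=4)
  print(example_candidates.shape)  # (2, 3, 4)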
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for TEAMS pre trainer network."""
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.modeling import activations
from official.nlp.modeling.networks import encoder_scaffold
from official.nlp.modeling.networks import packed_sequence_embedding
from official.nlp.projects.teams import teams_pretrainer
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class TeamsPretrainerTest(keras_parameterized.TestCase):
# Build a transformer network to use within the TEAMS trainer.
def _get_network(self, vocab_size):
sequence_length = 512
hidden_size = 50
embedding_cfg = {
'vocab_size': vocab_size,
'type_vocab_size': 1,
'hidden_size': hidden_size,
'embedding_width': hidden_size,
'max_seq_length': sequence_length,
'initializer': tf.keras.initializers.TruncatedNormal(stddev=0.02),
'dropout_rate': 0.1,
}
embedding_inst = packed_sequence_embedding.PackedSequenceEmbedding(
**embedding_cfg)
hidden_cfg = {
'num_attention_heads':
2,
'intermediate_size':
3072,
'intermediate_activation':
activations.gelu,
'dropout_rate':
0.1,
'attention_dropout_rate':
0.1,
'kernel_initializer':
tf.keras.initializers.TruncatedNormal(stddev=0.02),
}
return encoder_scaffold.EncoderScaffold(
num_hidden_instances=2,
pooled_output_dim=hidden_size,
embedding_cfg=embedding_cfg,
embedding_cls=embedding_inst,
hidden_cfg=hidden_cfg,
dict_outputs=True)
def test_teams_pretrainer(self):
"""Validate that the Keras object can be created."""
vocab_size = 100
test_generator_network = self._get_network(vocab_size)
test_discriminator_network = self._get_network(vocab_size)
# Create a TEAMS trainer with the created network.
candidate_size = 3
teams_trainer_model = teams_pretrainer.TeamsPretrainer(
generator_network=test_generator_network,
discriminator_mws_network=test_discriminator_network,
num_discriminator_task_agnostic_layers=1,
vocab_size=vocab_size,
candidate_size=candidate_size)
# Create a set of 2-dimensional inputs (the first dimension is implicit).
num_token_predictions = 2
sequence_length = 128
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
lm_positions = tf.keras.Input(
shape=(num_token_predictions,), dtype=tf.int32)
lm_ids = tf.keras.Input(shape=(num_token_predictions,), dtype=tf.int32)
inputs = {
'input_word_ids': word_ids,
'input_mask': mask,
'input_type_ids': type_ids,
'masked_lm_positions': lm_positions,
'masked_lm_ids': lm_ids
}
# Invoke the trainer model on the inputs. This causes the layer to be built.
outputs = teams_trainer_model(inputs)
lm_outs = outputs['lm_outputs']
disc_rtd_logits = outputs['disc_rtd_logits']
disc_rtd_label = outputs['disc_rtd_label']
disc_mws_logits = outputs['disc_mws_logits']
disc_mws_label = outputs['disc_mws_label']
# Validate that the outputs are of the expected shape.
expected_lm_shape = [None, num_token_predictions, vocab_size]
expected_disc_rtd_logits_shape = [None, sequence_length]
expected_disc_rtd_label_shape = [None, sequence_length]
expected_disc_disc_mws_logits_shape = [
None, num_token_predictions, candidate_size
]
expected_disc_disc_mws_label_shape = [None, num_token_predictions]
self.assertAllEqual(expected_lm_shape, lm_outs.shape.as_list())
self.assertAllEqual(expected_disc_rtd_logits_shape,
disc_rtd_logits.shape.as_list())
self.assertAllEqual(expected_disc_rtd_label_shape,
disc_rtd_label.shape.as_list())
self.assertAllEqual(expected_disc_disc_mws_logits_shape,
disc_mws_logits.shape.as_list())
self.assertAllEqual(expected_disc_disc_mws_label_shape,
disc_mws_label.shape.as_list())
def test_teams_trainer_tensor_call(self):
"""Validate that the Keras object can be invoked."""
vocab_size = 100
test_generator_network = self._get_network(vocab_size)
test_discriminator_network = self._get_network(vocab_size)
# Create a TEAMS trainer with the created network.
teams_trainer_model = teams_pretrainer.TeamsPretrainer(
generator_network=test_generator_network,
discriminator_mws_network=test_discriminator_network,
num_discriminator_task_agnostic_layers=2,
vocab_size=vocab_size,
candidate_size=2)
# Create a set of 2-dimensional data tensors to feed into the model.
word_ids = tf.constant([[1, 1, 1], [2, 2, 2]], dtype=tf.int32)
mask = tf.constant([[1, 1, 1], [1, 0, 0]], dtype=tf.int32)
type_ids = tf.constant([[1, 1, 1], [2, 2, 2]], dtype=tf.int32)
lm_positions = tf.constant([[0, 1], [0, 2]], dtype=tf.int32)
lm_ids = tf.constant([[10, 20], [20, 30]], dtype=tf.int32)
inputs = {
'input_word_ids': word_ids,
'input_mask': mask,
'input_type_ids': type_ids,
'masked_lm_positions': lm_positions,
'masked_lm_ids': lm_ids
}
# Invoke the trainer model on the tensors. In Eager mode, this does the
# actual calculation. (We can't validate the outputs, since the network is
# too complex: this simply ensures we're not hitting runtime errors.)
_ = teams_trainer_model(inputs)
def test_serialize_deserialize(self):
"""Validate that the TEAMS trainer can be serialized and deserialized."""
vocab_size = 100
test_generator_network = self._get_network(vocab_size)
test_discriminator_network = self._get_network(vocab_size)
# Create a TEAMS trainer with the created network. (Note that all the args
# are different, so we can catch any serialization mismatches.)
teams_trainer_model = teams_pretrainer.TeamsPretrainer(
generator_network=test_generator_network,
discriminator_mws_network=test_discriminator_network,
num_discriminator_task_agnostic_layers=2,
vocab_size=vocab_size,
candidate_size=2)
# Create another TEAMS trainer via serialization and deserialization.
config = teams_trainer_model.get_config()
new_teams_trainer_model = teams_pretrainer.TeamsPretrainer.from_config(
config)
# Validate that the config can be serialized to JSON.
_ = new_teams_trainer_model.to_json()
# If the serialization was successful, the new config should match the old.
self.assertAllEqual(teams_trainer_model.get_config(),
new_teams_trainer_model.get_config())
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TEAMS pretraining task (Joint Masked LM, Replaced Token Detection and )."""
import dataclasses
import tensorflow as tf
from official.core import base_task
from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling import tf_utils
from official.nlp.data import pretrain_dataloader
from official.nlp.modeling import layers
from official.nlp.projects.teams import teams
from official.nlp.projects.teams import teams_pretrainer
@dataclasses.dataclass
class TeamsPretrainTaskConfig(cfg.TaskConfig):
"""The model config."""
model: teams.TeamsPretrainerConfig = teams.TeamsPretrainerConfig()
train_data: cfg.DataConfig = cfg.DataConfig()
validation_data: cfg.DataConfig = cfg.DataConfig()
def _get_generator_hidden_layers(discriminator_network, num_hidden_layers,
num_shared_layers):
if num_shared_layers <= 0:
num_shared_layers = 0
hidden_layers = []
else:
hidden_layers = discriminator_network.hidden_layers[:num_shared_layers]
for _ in range(num_shared_layers, num_hidden_layers):
hidden_layers.append(layers.Transformer)
return hidden_layers
def _build_pretrainer(
config: teams.TeamsPretrainerConfig) -> teams_pretrainer.TeamsPretrainer:
"""Instantiates TeamsPretrainer from the config."""
generator_encoder_cfg = config.generator
discriminator_encoder_cfg = config.discriminator
discriminator_network = teams.get_encoder(discriminator_encoder_cfg)
# Copy discriminator's embeddings to generator for easier model serialization.
hidden_layers = _get_generator_hidden_layers(
discriminator_network, generator_encoder_cfg.num_layers,
config.num_shared_generator_hidden_layers)
if config.tie_embeddings:
generator_network = teams.get_encoder(
generator_encoder_cfg,
embedding_network=discriminator_network.embedding_network,
hidden_layers=hidden_layers)
else:
generator_network = teams.get_encoder(
generator_encoder_cfg, hidden_layers=hidden_layers)
return teams_pretrainer.TeamsPretrainer(
generator_network=generator_network,
discriminator_mws_network=discriminator_network,
num_discriminator_task_agnostic_layers=config
.num_discriminator_task_agnostic_layers,
vocab_size=generator_encoder_cfg.vocab_size,
candidate_size=config.candidate_size,
mlm_activation=tf_utils.get_activation(
generator_encoder_cfg.hidden_activation),
mlm_initializer=tf.keras.initializers.TruncatedNormal(
stddev=generator_encoder_cfg.initializer_range))
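# Illustrative sketch (not part of this module): building the pretrainer
# directly from a small config, mirroring TeamsPretrainTask.build_model below;
# the encoder settings are assumptions for this example.
if __name__ == '__main__':
  from official.nlp.configs import encoders  # Imported here only for the sketch.
  example_config = teams.TeamsPretrainerConfig(
      generator=encoders.BertEncoderConfig(vocab_size=30522, num_layers=2),
      discriminator=encoders.BertEncoderConfig(vocab_size=30522, num_layers=2),
      num_shared_generator_hidden_layers=1,
      num_discriminator_task_agnostic_layers=1)
  example_pretrainer = _build_pretrainer(example_config)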
@task_factory.register_task_cls(TeamsPretrainTaskConfig)
class TeamsPretrainTask(base_task.Task):
"""TEAMS Pretrain Task (Masked LM + RTD + MWS)."""
def build_model(self):
return _build_pretrainer(self.task_config.model)
def build_losses(self,
labels,
model_outputs,
metrics,
aux_losses=None) -> tf.Tensor:
with tf.name_scope('TeamsPretrainTask/losses'):
metrics = dict([(metric.name, metric) for metric in metrics])
# Generator MLM loss.
lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy(
labels['masked_lm_ids'],
tf.cast(model_outputs['lm_outputs'], tf.float32),
from_logits=True)
lm_label_weights = labels['masked_lm_weights']
lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_label_weights)
lm_denominator_loss = tf.reduce_sum(lm_label_weights)
mlm_loss = tf.math.divide_no_nan(lm_numerator_loss, lm_denominator_loss)
metrics['masked_lm_loss'].update_state(mlm_loss)
weight = self.task_config.model.generator_loss_weight
total_loss = weight * mlm_loss
# Discriminator RTD loss.
rtd_logits = model_outputs['disc_rtd_logits']
rtd_labels = tf.cast(model_outputs['disc_rtd_label'], tf.float32)
input_mask = tf.cast(labels['input_mask'], tf.float32)
rtd_ind_loss = tf.nn.sigmoid_cross_entropy_with_logits(
logits=rtd_logits, labels=rtd_labels)
rtd_numerator = tf.reduce_sum(input_mask * rtd_ind_loss)
rtd_denominator = tf.reduce_sum(input_mask)
rtd_loss = tf.math.divide_no_nan(rtd_numerator, rtd_denominator)
metrics['replaced_token_detection_loss'].update_state(rtd_loss)
weight = self.task_config.model.discriminator_rtd_loss_weight
total_loss = total_loss + weight * rtd_loss
# Discriminator MWS loss.
mws_logits = model_outputs['disc_mws_logits']
mws_labels = model_outputs['disc_mws_label']
mws_loss = tf.keras.losses.sparse_categorical_crossentropy(
mws_labels, mws_logits, from_logits=True)
mws_numerator_loss = tf.reduce_sum(mws_loss * lm_label_weights)
mws_denominator_loss = tf.reduce_sum(lm_label_weights)
mws_loss = tf.math.divide_no_nan(mws_numerator_loss, mws_denominator_loss)
metrics['multiword_selection_loss'].update_state(mws_loss)
weight = self.task_config.model.discriminator_mws_loss_weight
total_loss = total_loss + weight * mws_loss
if aux_losses:
total_loss += tf.add_n(aux_losses)
metrics['total_loss'].update_state(total_loss)
return total_loss
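# With the default TeamsPretrainerConfig weights, the value returned above is
# 1.0 * mlm_loss + 5.0 * rtd_loss + 2.0 * mws_loss, plus any auxiliary losses.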
def build_inputs(self, params, input_context=None):
"""Returns tf.data.Dataset for pretraining."""
if params.input_path == 'dummy':
def dummy_data(_):
dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
dummy_lm = tf.zeros((1, params.max_predictions_per_seq), dtype=tf.int32)
return dict(
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids,
masked_lm_positions=dummy_lm,
masked_lm_ids=dummy_lm,
masked_lm_weights=tf.cast(dummy_lm, dtype=tf.float32))
dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat()
dataset = dataset.map(
dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset
return pretrain_dataloader.BertPretrainDataLoader(params).load(
input_context)
def build_metrics(self, training=None):
del training
metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(name='masked_lm_accuracy'),
tf.keras.metrics.Mean(name='masked_lm_loss'),
tf.keras.metrics.SparseCategoricalAccuracy(
name='replaced_token_detection_accuracy'),
tf.keras.metrics.Mean(name='replaced_token_detection_loss'),
tf.keras.metrics.SparseCategoricalAccuracy(
name='multiword_selection_accuracy'),
tf.keras.metrics.Mean(name='multiword_selection_loss'),
tf.keras.metrics.Mean(name='total_loss'),
]
return metrics
def process_metrics(self, metrics, labels, model_outputs):
with tf.name_scope('TeamsPretrainTask/process_metrics'):
metrics = dict([(metric.name, metric) for metric in metrics])
if 'masked_lm_accuracy' in metrics:
metrics['masked_lm_accuracy'].update_state(labels['masked_lm_ids'],
model_outputs['lm_outputs'],
labels['masked_lm_weights'])
if 'replaced_token_detection_accuracy' in metrics:
rtd_logits_expanded = tf.expand_dims(model_outputs['disc_rtd_logits'],
-1)
rtd_full_logits = tf.concat(
[-1.0 * rtd_logits_expanded, rtd_logits_expanded], -1)
metrics['replaced_token_detection_accuracy'].update_state(
model_outputs['disc_rtd_label'], rtd_full_logits,
labels['input_mask'])
if 'multiword_selection_accuracy' in metrics:
metrics['multiword_selection_accuracy'].update_state(
model_outputs['disc_mws_label'], model_outputs['disc_mws_logits'],
labels['masked_lm_weights'])
def train_step(self, inputs, model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer, metrics):
"""Does forward and backward.
Args:
inputs: a dictionary of input tensors.
model: the model, forward pass definition.
optimizer: the optimizer for this training step.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
with tf.GradientTape() as tape:
outputs = model(inputs, training=True)
# Computes per-replica loss.
loss = self.build_losses(
labels=inputs,
model_outputs=outputs,
metrics=metrics,
aux_losses=model.losses)
# Scales the loss, as the default gradient allreduce performs a sum inside
# the optimizer.
scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
tvars = model.trainable_variables
grads = tape.gradient(scaled_loss, tvars)
optimizer.apply_gradients(list(zip(grads, tvars)))
self.process_metrics(metrics, inputs, outputs)
return {self.loss: loss}
def validation_step(self, inputs, model: tf.keras.Model, metrics):
"""Validatation step.
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
outputs = model(inputs, training=False)
loss = self.build_losses(
labels=inputs,
model_outputs=outputs,
metrics=metrics,
aux_losses=model.losses)
self.process_metrics(metrics, inputs, outputs)
return {self.loss: loss}
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for teams_task."""
from absl.testing import parameterized
import tensorflow as tf
from official.nlp.configs import encoders
from official.nlp.data import pretrain_dataloader
from official.nlp.projects.teams import teams
from official.nlp.projects.teams import teams_task
class TeamsPretrainTaskTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters((1, 1), (0, 1), (0, 0), (1, 0))
def test_task(self, num_shared_hidden_layers,
num_task_agnostic_layers):
config = teams_task.TeamsPretrainTaskConfig(
model=teams.TeamsPretrainerConfig(
generator=encoders.BertEncoderConfig(
vocab_size=30522, num_layers=2),
discriminator=encoders.BertEncoderConfig(
vocab_size=30522, num_layers=2),
num_shared_generator_hidden_layers=num_shared_hidden_layers,
num_discriminator_task_agnostic_layers=num_task_agnostic_layers,
),
train_data=pretrain_dataloader.BertPretrainDataConfig(
input_path="dummy",
max_predictions_per_seq=20,
seq_length=128,
global_batch_size=1))
task = teams_task.TeamsPretrainTask(config)
model = task.build_model()
metrics = task.build_metrics()
dataset = task.build_inputs(config.train_data)
iterator = iter(dataset)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
task.train_step(next(iterator), model, optimizer, metrics=metrics)
task.validation_step(next(iterator), model, metrics=metrics)
if __name__ == "__main__":
tf.test.main()
......@@ -48,15 +48,15 @@ def _flatten_dims(tensor: tf.Tensor,
rank = tensor.shape.rank
if rank is None:
raise ValueError('Static rank of `tensor` must be known.')
if first_dim < 0:
if first_dim < 0: # pytype: disable=unsupported-operands
first_dim += rank
if first_dim < 0 or first_dim >= rank:
if first_dim < 0 or first_dim >= rank: # pytype: disable=unsupported-operands
raise ValueError('`first_dim` out of bounds for `tensor` rank.')
if last_dim < 0:
if last_dim < 0: # pytype: disable=unsupported-operands
last_dim += rank
if last_dim < 0 or last_dim >= rank:
if last_dim < 0 or last_dim >= rank: # pytype: disable=unsupported-operands
raise ValueError('`last_dim` out of bounds for `tensor` rank.')
if first_dim > last_dim:
if first_dim > last_dim: # pytype: disable=unsupported-operands
raise ValueError('`first_dim` must not be larger than `last_dim`.')
# Try to calculate static flattened dim size if all input sizes to flatten
......
......@@ -13,12 +13,19 @@
# limitations under the License.
"""Common library to export a SavedModel from the export module."""
import os
import time
from typing import Dict, List, Optional, Text, Union
from absl import logging
import tensorflow as tf
from official.core import export_base
MAX_DIRECTORY_CREATION_ATTEMPTS = 10
def export(export_module: export_base.ExportModule,
function_keys: Union[List[Text], Dict[Text, Text]],
export_savedmodel_dir: Text,
......@@ -39,7 +46,39 @@ def export(export_module: export_base.ExportModule,
The savedmodel directory path.
"""
save_options = tf.saved_model.SaveOptions(function_aliases={
"tpu_candidate": export_module.serve,
'tpu_candidate': export_module.serve,
})
return export_base.export(export_module, function_keys, export_savedmodel_dir,
checkpoint_path, timestamped, save_options)
def get_timestamped_export_dir(export_dir_base):
"""Builds a path to a new subdirectory within the base directory.
Args:
export_dir_base: A string containing a directory to write the exported graph
and checkpoints.
Returns:
The full path of the new subdirectory (which is not actually created yet).
Raises:
RuntimeError: if repeated attempts fail to obtain a unique timestamped
directory name.
"""
attempts = 0
while attempts < MAX_DIRECTORY_CREATION_ATTEMPTS:
timestamp = int(time.time())
result_dir = os.path.join(export_dir_base, str(timestamp))
if not tf.io.gfile.exists(result_dir):
# Collisions are still possible (though extremely unlikely): this
# directory is not actually created yet, but it will be created almost
# instantly on return from this function.
return result_dir
time.sleep(1)
attempts += 1
logging.warning('Directory %s already exists; retrying (attempt %s/%s)',
str(result_dir), attempts, MAX_DIRECTORY_CREATION_ATTEMPTS)
raise RuntimeError('Failed to obtain a unique export directory name after '
f'{MAX_DIRECTORY_CREATION_ATTEMPTS} attempts.')
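# For example, get_timestamped_export_dir('/models/teams') would return a path
# like '/models/teams/1620000000', using the current Unix time as the
# subdirectory name (the base path shown here is hypothetical).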
......@@ -80,11 +80,10 @@ class SentencePrediction(export_base.ExportModule):
lower_case=params.lower_case,
preprocessing_hub_module_url=params.preprocessing_hub_module_url)
@tf.function
def serve(self,
input_word_ids,
input_mask=None,
input_type_ids=None) -> Dict[str, tf.Tensor]:
def _serve_tokenized_input(self,
input_word_ids,
input_mask=None,
input_type_ids=None) -> tf.Tensor:
if input_type_ids is None:
# Requires the CLS token to be the first token of the inputs.
input_type_ids = tf.zeros_like(input_word_ids)
......@@ -97,7 +96,26 @@ class SentencePrediction(export_base.ExportModule):
input_word_ids=input_word_ids,
input_mask=input_mask,
input_type_ids=input_type_ids)
return dict(outputs=self.inference_step(inputs))
return self.inference_step(inputs)
@tf.function
def serve(self,
input_word_ids,
input_mask=None,
input_type_ids=None) -> Dict[str, tf.Tensor]:
return dict(
outputs=self._serve_tokenized_input(input_word_ids, input_mask,
input_type_ids))
@tf.function
def serve_probability(self,
input_word_ids,
input_mask=None,
input_type_ids=None) -> Dict[str, tf.Tensor]:
return dict(
outputs=tf.nn.softmax(
self._serve_tokenized_input(input_word_ids, input_mask,
input_type_ids)))
@tf.function
def serve_examples(self, inputs) -> Dict[str, tf.Tensor]:
......
......@@ -13,10 +13,10 @@
# limitations under the License.
"""Sentence prediction (classification) task."""
import dataclasses
from typing import List, Union, Optional
from absl import logging
import dataclasses
import numpy as np
import orbit
from scipy import stats
......@@ -140,15 +140,26 @@ class SentencePredictionTask(base_task.Task):
del training
if self.task_config.model.num_classes == 1:
metrics = [tf.keras.metrics.MeanSquaredError()]
elif self.task_config.model.num_classes == 2:
metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy'),
tf.keras.metrics.AUC(name='auc', curve='PR'),
]
else:
metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')
tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy'),
]
return metrics
def process_metrics(self, metrics, labels, model_outputs):
for metric in metrics:
metric.update_state(labels[self.label_field], model_outputs)
if metric.name == 'auc':
# Converts the logits to probabilities and extracts the probability of True.
metric.update_state(
labels[self.label_field],
tf.expand_dims(tf.nn.softmax(model_outputs)[:, 1], axis=1))
if metric.name == 'cls_accuracy':
metric.update_state(labels[self.label_field], model_outputs)
def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
compiled_metrics.update_state(labels[self.label_field], model_outputs)
......
......@@ -13,11 +13,11 @@
# limitations under the License.
"""Defines the translation task."""
import dataclasses
import os
from typing import Optional
from absl import logging
import dataclasses
import sacrebleu
import tensorflow as tf
import tensorflow_text as tftxt
......
......@@ -85,7 +85,8 @@ class TranslationTaskTest(tf.test.TestCase):
def test_task(self):
config = translation.TranslationConfig(
model=translation.ModelConfig(
encoder=translation.EncDecoder(), decoder=translation.EncDecoder()),
encoder=translation.EncDecoder(num_layers=1),
decoder=translation.EncDecoder(num_layers=1)),
train_data=wmt_dataloader.WMTDataConfig(
input_path=self._record_input_path,
src_lang="en", tgt_lang="reverse_en",
......@@ -102,7 +103,8 @@ class TranslationTaskTest(tf.test.TestCase):
def test_no_sentencepiece_path(self):
config = translation.TranslationConfig(
model=translation.ModelConfig(
encoder=translation.EncDecoder(), decoder=translation.EncDecoder()),
encoder=translation.EncDecoder(num_layers=1),
decoder=translation.EncDecoder(num_layers=1)),
train_data=wmt_dataloader.WMTDataConfig(
input_path=self._record_input_path,
src_lang="en", tgt_lang="reverse_en",
......@@ -122,7 +124,8 @@ class TranslationTaskTest(tf.test.TestCase):
sentencepeice_model_prefix)
config = translation.TranslationConfig(
model=translation.ModelConfig(
encoder=translation.EncDecoder(), decoder=translation.EncDecoder()),
encoder=translation.EncDecoder(num_layers=1),
decoder=translation.EncDecoder(num_layers=1)),
train_data=wmt_dataloader.WMTDataConfig(
input_path=self._record_input_path,
src_lang="en", tgt_lang="reverse_en",
......@@ -137,7 +140,8 @@ class TranslationTaskTest(tf.test.TestCase):
def test_evaluation(self):
config = translation.TranslationConfig(
model=translation.ModelConfig(
encoder=translation.EncDecoder(), decoder=translation.EncDecoder(),
encoder=translation.EncDecoder(num_layers=1),
decoder=translation.EncDecoder(num_layers=1),
padded_decode=False,
decode_max_length=64),
validation_data=wmt_dataloader.WMTDataConfig(
......
......@@ -27,9 +27,15 @@ from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
from official.nlp import continuous_finetune_lib
FLAGS = flags.FLAGS
flags.DEFINE_integer(
'pretrain_steps',
default=None,
help='The number of total training steps for the pretraining job.')
def main(_):
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
......@@ -40,27 +46,33 @@ def main(_):
# may race against the train job for writing the same file.
train_utils.serialize_config(params, model_dir)
# Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
# can have significant impact on model speeds by utilizing float16 in case of
# GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
# dtype is float16
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu,
**params.runtime.model_parallelism())
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
if FLAGS.mode == 'continuous_train_and_eval':
continuous_finetune_lib.run_continuous_finetune(
FLAGS.mode, params, model_dir, pretrain_steps=FLAGS.pretrain_steps)
else:
# Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
# can have significant impact on model speeds by utilizing float16 in case
# of GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only
# when dtype is float16
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(
params.runtime.mixed_precision_dtype)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu,
**params.runtime.model_parallelism())
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=FLAGS.mode,
params=params,
model_dir=model_dir)
train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=FLAGS.mode,
params=params,
model_dir=model_dir)
train_utils.save_gin_config(FLAGS.mode, model_dir)
......
......@@ -81,6 +81,7 @@ setup(
'official.pip_package*',
'official.benchmark*',
'official.colab*',
'official.recommendation.ranking.data.preprocessing*',
]),
exclude_package_data={
'': ['*_test.py',],
......
# BASNet: Boundary-Aware Salient Object Detection
This repository is the unofficial implementation of the following paper. Please
see the paper
[BASNet: Boundary-Aware Salient Object Detection](https://openaccess.thecvf.com/content_CVPR_2019/html/Qin_BASNet_Boundary-Aware_Salient_Object_Detection_CVPR_2019_paper.html)
for more details.
## Requirements
[![TensorFlow 2.4](https://img.shields.io/badge/TensorFlow-2.4-FF6F00?logo=tensorflow)](https://github.com/tensorflow/tensorflow/releases/tag/v2.4.0)
[![Python 3.7](https://img.shields.io/badge/Python-3.7-3776AB)](https://www.python.org/downloads/release/python-379/)
## Train
```shell
$ python3 train.py \
--experiment=basnet_duts \
--mode=train \
--model_dir=$MODEL_DIR \
--config_file=./configs/experiments/basnet_dut_gpu.yaml
```
## Test
```shell
$ python3 train.py \
--experiment=basnet_duts \
--mode=eval \
--model_dir=$MODEL_DIR \
--config_file=./configs/experiments/basnet_dut_gpu.yaml \
--params_override='runtime.num_gpus=1, runtime.distribution_strategy=one_device, task.model.input_size=[256, 256, 3]'
```
## Results
Dataset | maxF<sub>β</sub> | relaxF<sub>β</sub> | MAE
:--------- | :--------------- | :------------------- | -------:
DUTS-TE | 0.865 | 0.793 | 0.046
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BASNet configuration definition."""
import dataclasses
import os
from typing import List, Optional, Union
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.modeling.hyperparams import config_definitions as cfg
from official.vision.beta.configs import common
@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
"""Input config for training."""
output_size: List[int] = dataclasses.field(default_factory=list)
# If crop_size is specified, the image will first be resized to output_size
# and then a crop of size crop_size will be taken.
crop_size: List[int] = dataclasses.field(default_factory=list)
input_path: str = ''
global_batch_size: int = 0
is_training: bool = True
dtype: str = 'float32'
shuffle_buffer_size: int = 1000
cycle_length: int = 10
resize_eval_groundtruth: bool = True
groundtruth_padded_size: List[int] = dataclasses.field(default_factory=list)
aug_rand_hflip: bool = True
file_type: str = 'tfrecord'
@dataclasses.dataclass
class BASNetModel(hyperparams.Config):
"""BASNet model config."""
input_size: List[int] = dataclasses.field(default_factory=list)
use_bias: bool = False
norm_activation: common.NormActivation = common.NormActivation()
@dataclasses.dataclass
class Losses(hyperparams.Config):
label_smoothing: float = 0.1
ignore_label: int = 0 # will be treated as background
l2_weight_decay: float = 0.0
use_groundtruth_dimension: bool = True
@dataclasses.dataclass
class BASNetTask(cfg.TaskConfig):
"""The model config."""
model: BASNetModel = BASNetModel()
train_data: DataConfig = DataConfig(is_training=True)
validation_data: DataConfig = DataConfig(is_training=False)
losses: Losses = Losses()
gradient_clip_norm: float = 0.0
init_checkpoint: Optional[str] = None
init_checkpoint_modules: Union[
str, List[str]] = 'backbone' # all, backbone, and/or decoder
@exp_factory.register_config_factory('basnet')
def basnet() -> cfg.ExperimentConfig:
"""BASNet general."""
return cfg.ExperimentConfig(
task=BASNetTask(),
trainer=cfg.TrainerConfig(),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
# DUTS Dataset
DUTS_TRAIN_EXAMPLES = 10553
DUTS_VAL_EXAMPLES = 5019
DUTS_INPUT_PATH_BASE_TR = 'DUTS_DATASET'
DUTS_INPUT_PATH_BASE_VAL = 'DUTS_DATASET'
@exp_factory.register_config_factory('basnet_duts')
def basnet_duts() -> cfg.ExperimentConfig:
"""Image segmentation on duts with basnet."""
train_batch_size = 64
eval_batch_size = 16
steps_per_epoch = DUTS_TRAIN_EXAMPLES // train_batch_size
config = cfg.ExperimentConfig(
task=BASNetTask(
model=BASNetModel(
input_size=[None, None, 3],
use_bias=True,
norm_activation=common.NormActivation(
activation='relu',
norm_momentum=0.99,
norm_epsilon=1e-3,
use_sync_bn=True)),
losses=Losses(l2_weight_decay=0),
train_data=DataConfig(
input_path=os.path.join(DUTS_INPUT_PATH_BASE_TR,
'tf_record_train'),
file_type='tfrecord',
crop_size=[224, 224],
output_size=[256, 256],
is_training=True,
global_batch_size=train_batch_size,
),
validation_data=DataConfig(
input_path=os.path.join(DUTS_INPUT_PATH_BASE_VAL,
'tf_record_test'),
file_type='tfrecord',
output_size=[256, 256],
is_training=False,
global_batch_size=eval_batch_size,
),
init_checkpoint='gs://cloud-basnet-checkpoints/basnet_encoder_imagenet/ckpt-340306',
init_checkpoint_modules='backbone'
),
trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
train_steps=300 * steps_per_epoch,
validation_steps=DUTS_VAL_EXAMPLES // eval_batch_size,
validation_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adam',
'adam': {
'beta_1': 0.9,
'beta_2': 0.999,
'epsilon': 1e-8,
}
},
'learning_rate': {
'type': 'constant',
'constant': {'learning_rate': 0.001}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
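# Example usage (a minimal sketch, assuming this module has been imported so
# that the factories above are registered; the override values are purely
# illustrative):
#
#   from official.core import exp_factory
#
#   config = exp_factory.get_exp_config('basnet_duts')
#   config.override({'task': {'train_data': {'global_batch_size': 32}}})
#   config.validate()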
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for basnet configs."""
# pylint: disable=unused-import
from absl.testing import parameterized
import tensorflow as tf
from official.core import exp_factory
from official.modeling.hyperparams import config_definitions as cfg
from official.projects.basnet.configs import basnet as exp_cfg
class BASNetConfigTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(('basnet_duts',))
def test_basnet_configs(self, config_name):
config = exp_factory.get_exp_config(config_name)
self.assertIsInstance(config, cfg.ExperimentConfig)
self.assertIsInstance(config.task, exp_cfg.BASNetTask)
self.assertIsInstance(config.task.model,
exp_cfg.BASNetModel)
self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
config.task.train_data.is_training = None
with self.assertRaises(KeyError):
config.validate()
if __name__ == '__main__':
tf.test.main()
runtime:
distribution_strategy: 'mirrored'
mixed_precision_dtype: 'float32'
num_gpus: 8
task:
train_data:
dtype: 'float32'
validation_data:
resize_eval_groundtruth: true
dtype: 'float32'
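# A minimal sketch of how an override file like this one can be applied to a
# registered experiment (assuming the Model Garden config helpers; the file
# name 'gpu_mirrored.yaml' is a placeholder):
#
#   from official.core import exp_factory
#   from official.modeling import hyperparams
#
#   config = exp_factory.get_exp_config('basnet_duts')
#   config = hyperparams.override_params_dict(
#       config, 'gpu_mirrored.yaml', is_strict=True)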
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation metrics for BASNet.
The MAE and maxFscore implementations are modified versions of
https://github.com/xuebinqin/Binary-Segmentation-Evaluation-Tool
"""
import numpy as np
import scipy.signal
class MAE:
"""Mean Absolute Error(MAE) metric for basnet."""
def __init__(self):
"""Constructs MAE metric class."""
self.reset_states()
@property
def name(self):
return 'MAE'
def reset_states(self):
"""Resets internal states for a fresh run."""
self._predictions = []
self._groundtruths = []
def result(self):
"""Evaluates segmentation results, and reset_states."""
metric_result = self.evaluate()
# Cleans up the internal variables in order for a fresh eval next time.
self.reset_states()
return metric_result
def evaluate(self):
"""Evaluates with masks from all images.
Returns:
      average_mae: average MAE as a float numpy value.
"""
mae_total = 0.0
for (true, pred) in zip(self._groundtruths, self._predictions):
# Computes MAE
mae = self._compute_mae(true, pred)
mae_total += mae
average_mae = mae_total / len(self._groundtruths)
return average_mae
def _mask_normalize(self, mask):
    return mask / (np.amax(mask) + 1e-8)
def _compute_mae(self, true, pred):
h, w = true.shape[0], true.shape[1]
mask1 = self._mask_normalize(true)
mask2 = self._mask_normalize(pred)
sum_error = np.sum(np.absolute((mask1.astype(float) - mask2.astype(float))))
mae_error = sum_error/(float(h)*float(w)+1e-8)
return mae_error
def _convert_to_numpy(self, groundtruths, predictions):
"""Converts tesnors to numpy arrays."""
numpy_groundtruths = groundtruths.numpy()
numpy_predictions = predictions.numpy()
return numpy_groundtruths, numpy_predictions
def update_state(self, groundtruths, predictions):
"""Update segmentation results and groundtruth data.
Args:
groundtruths : Tuple of single Tensor [batch, width, height, 1],
groundtruth masks. range [0, 1]
predictions : Tuple of single Tensor [batch, width, height, 1],
predicted masks. range [0, 1]
"""
groundtruths, predictions = self._convert_to_numpy(groundtruths[0],
predictions[0])
for (true, pred) in zip(groundtruths, predictions):
self._groundtruths.append(true)
self._predictions.append(pred)
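# Example usage of the metric above (a minimal sketch; `gt` and `pred` are
# assumed to be float tensors of shape [batch, height, width, 1] with values
# in [0, 1]):
#
#   mae_metric = MAE()
#   mae_metric.update_state([gt], [pred])  # tuples holding a single tensor
#   average_mae = mae_metric.result()      # evaluates and resets the state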
class MaxFscore:
"""Maximum F-score metric for basnet."""
def __init__(self):
"""Constructs BASNet evaluation class."""
self.reset_states()
@property
def name(self):
return 'MaxFScore'
def reset_states(self):
"""Resets internal states for a fresh run."""
self._predictions = []
self._groundtruths = []
def result(self):
"""Evaluates segmentation results, and reset_states."""
metric_result = self.evaluate()
# Cleans up the internal variables in order for a fresh eval next time.
self.reset_states()
return metric_result
def evaluate(self):
"""Evaluates with masks from all images.
Returns:
f_max: maximum F-score value.
"""
mybins = np.arange(0, 256)
beta = 0.3
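    # Note: `beta` here plays the role of beta^2 in the usual F-beta formula;
    # beta^2 = 0.3 is the conventional setting for salient object detection.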
precisions = np.zeros((len(self._groundtruths), len(mybins)-1))
recalls = np.zeros((len(self._groundtruths), len(mybins)-1))
for i, (true, pred) in enumerate(zip(self._groundtruths,
self._predictions)):
# Compute F-score
true = self._mask_normalize(true) * 255.0
pred = self._mask_normalize(pred) * 255.0
pre, rec = self._compute_pre_rec(true, pred, mybins=np.arange(0, 256))
precisions[i, :] = pre
recalls[i, :] = rec
precisions = np.sum(precisions, 0) / (len(self._groundtruths) + 1e-8)
recalls = np.sum(recalls, 0) / (len(self._groundtruths) + 1e-8)
f = (1 + beta) * precisions * recalls / (beta * precisions + recalls + 1e-8)
f_max = np.max(f)
f_max = f_max.astype(np.float32)
return f_max
def _mask_normalize(self, mask):
return mask / (np.amax(mask) + 1e-8)
def _compute_pre_rec(self, true, pred, mybins=np.arange(0, 256)):
"""Computes relaxed precision and recall."""
# pixel number of ground truth foreground regions
gt_num = true[true > 128].size
# mask predicted pixel values in the ground truth foreground region
pp = pred[true > 128]
    # mask predicted pixel values in the ground truth background region
nn = pred[true <= 128]
pp_hist, _ = np.histogram(pp, bins=mybins)
nn_hist, _ = np.histogram(nn, bins=mybins)
pp_hist_flip = np.flipud(pp_hist)
nn_hist_flip = np.flipud(nn_hist)
pp_hist_flip_cum = np.cumsum(pp_hist_flip)
nn_hist_flip_cum = np.cumsum(nn_hist_flip)
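    # Flipping and accumulating the per-bin counts gives, for each threshold
    # from high to low, the number of pixels predicted above that threshold:
    # true positives from the foreground histogram and false positives from
    # the background histogram.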
    # TP / (TP + FP)
    precision = pp_hist_flip_cum / (pp_hist_flip_cum + nn_hist_flip_cum + 1e-8)
recall = pp_hist_flip_cum / (gt_num + 1e-8) # TP/(TP+FN)
precision[np.isnan(precision)] = 0.0
recall[np.isnan(recall)] = 0.0
pre_len = len(precision)
rec_len = len(recall)
return np.reshape(precision, (pre_len)), np.reshape(recall, (rec_len))
def _convert_to_numpy(self, groundtruths, predictions):
"""Converts tesnors to numpy arrays."""
numpy_groundtruths = groundtruths.numpy()
numpy_predictions = predictions.numpy()
return numpy_groundtruths, numpy_predictions
def update_state(self, groundtruths, predictions):
"""Update segmentation results and groundtruth data.
Args:
groundtruths : Tuple of single Tensor [batch, width, height, 1],
groundtruth masks. range [0, 1]
      predictions : Tuple of single Tensor [batch, width, height, 1],
predicted masks. range [0, 1]
"""
groundtruths, predictions = self._convert_to_numpy(groundtruths[0],
predictions[0])
for (true, pred) in zip(groundtruths, predictions):
self._groundtruths.append(true)
self._predictions.append(pred)
class RelaxedFscore:
"""Relaxed F-score metric for basnet."""
def __init__(self):
"""Constructs BASNet evaluation class."""
self.reset_states()
@property
def name(self):
return 'RelaxFScore'
def reset_states(self):
"""Resets internal states for a fresh run."""
self._predictions = []
self._groundtruths = []
def result(self):
"""Evaluates segmentation results, and reset_states."""
metric_result = self.evaluate()
# Cleans up the internal variables in order for a fresh eval next time.
self.reset_states()
return metric_result
def evaluate(self):
"""Evaluates with masks from all images.
Returns:
relax_f: relaxed F-score value.
"""
beta = 0.3
rho = 3
relax_fs = np.zeros(len(self._groundtruths))
erode_kernel = np.ones((3, 3))
for i, (true,
pred) in enumerate(zip(self._groundtruths, self._predictions)):
true = self._mask_normalize(true)
pred = self._mask_normalize(pred)
true = np.squeeze(true, axis=-1)
pred = np.squeeze(pred, axis=-1)
# binary saliency mask (S_bw), threshold 0.5
pred[pred >= 0.5] = 1
pred[pred < 0.5] = 0
# compute eroded binary mask (S_erd) of S_bw
pred_erd = self._compute_erosion(pred, erode_kernel)
pred_xor = np.logical_xor(pred_erd, pred)
# convert True/False to 1/0
pred_xor = pred_xor * 1
# same method for ground truth
true[true >= 0.5] = 1
true[true < 0.5] = 0
true_erd = self._compute_erosion(true, erode_kernel)
true_xor = np.logical_xor(true_erd, true)
true_xor = true_xor * 1
pre, rec = self._compute_relax_pre_rec(true_xor, pred_xor, rho)
relax_fs[i] = (1 + beta) * pre * rec / (beta * pre + rec + 1e-8)
relax_f = np.sum(relax_fs, 0) / (len(self._groundtruths) + 1e-8)
relax_f = relax_f.astype(np.float32)
return relax_f
def _mask_normalize(self, mask):
    return mask / (np.amax(mask) + 1e-8)
def _compute_erosion(self, mask, kernel):
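    # Morphological binary erosion via convolution: a pixel stays foreground
    # only if every pixel under the all-ones kernel is foreground.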
kernel_full = np.sum(kernel)
mask_erd = scipy.signal.convolve2d(mask, kernel, mode='same')
mask_erd[mask_erd < kernel_full] = 0
mask_erd[mask_erd == kernel_full] = 1
return mask_erd
def _compute_relax_pre_rec(self, true, pred, rho):
"""Computes relaxed precision and recall."""
kernel = np.ones((2 * rho - 1, 2 * rho - 1))
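    # Convolving with the all-ones kernel marks every pixel that has a
    # boundary pixel within its (2*rho-1) x (2*rho-1) neighborhood, so relaxed
    # precision is the fraction of predicted boundary pixels near the
    # groundtruth boundary and relaxed recall is the fraction of groundtruth
    # boundary pixels near the predicted boundary.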
map_zeros = np.zeros_like(pred)
map_ones = np.ones_like(pred)
pred_filtered = scipy.signal.convolve2d(pred, kernel, mode='same')
# True positive for relaxed precision
relax_pre_tp = np.where((true == 1) & (pred_filtered > 0), map_ones,
map_zeros)
true_filtered = scipy.signal.convolve2d(true, kernel, mode='same')
# True positive for relaxed recall
relax_rec_tp = np.where((pred == 1) & (true_filtered > 0), map_ones,
map_zeros)
return np.sum(relax_pre_tp) / np.sum(pred), np.sum(relax_rec_tp) / np.sum(
true)
def _convert_to_numpy(self, groundtruths, predictions):
"""Converts tesnors to numpy arrays."""
numpy_groundtruths = groundtruths.numpy()
numpy_predictions = predictions.numpy()
return numpy_groundtruths, numpy_predictions
def update_state(self, groundtruths, predictions):
"""Update segmentation results and groundtruth data.
Args:
groundtruths : Tuple of single Tensor [batch, width, height, 1],
groundtruth masks. range [0, 1]
predictions : Tuple of single Tensor [batch, width, height, 1],
predicted masks. range [0, 1]
"""
groundtruths, predictions = self._convert_to_numpy(groundtruths[0],
predictions[0])
for (true, pred) in zip(groundtruths, predictions):
self._groundtruths.append(true)
self._predictions.append(pred)