Commit b1025b3b authored by syiming

Merge remote-tracking branch 'upstream/master' into fasterrcnn_fpn_keras_feature_extractor

parents 69ce1c45 e9df75ab
......@@ -32,9 +32,9 @@ from official.vision.segmentation import unet_model as unet_model_lib
UNET3D_MIN_ACCURACY = 0.90
UNET3D_MAX_ACCURACY = 0.98
UNET_TRAINING_FILES = 'unet_training_data_files'
UNET_EVAL_FILES = 'unet_eval_data_files'
UNET_MODEL_CONFIG_FILE = 'unet_model_config'
UNET_TRAINING_FILES = 'gs://mlcompass-data/unet3d/train_data/*'
UNET_EVAL_FILES = 'gs://mlcompass-data/unet3d/eval_data/*'
UNET_MODEL_CONFIG_FILE = 'gs://mlcompass-data/unet3d/config/unet_config.yaml'
FLAGS = flags.FLAGS
......
......@@ -14,15 +14,18 @@
# limitations under the License.
# ==============================================================================
"""Defines the base task abstraction."""
import abc
import functools
from typing import Any, Callable, Optional
import six
import tensorflow as tf
from official.modeling.hyperparams import config_definitions as cfg
from official.utils import registry
@six.add_metaclass(abc.ABCMeta)
class Task(tf.Module):
"""A single-replica view of training procedure.
......@@ -54,14 +57,13 @@ class Task(tf.Module):
"""
pass
@abc.abstractmethod
def build_model(self) -> tf.keras.Model:
"""Creates the model architecture.
Returns:
A model instance.
"""
# TODO(hongkuny): the base task should call network factory.
pass
def compile_model(self,
model: tf.keras.Model,
......@@ -98,6 +100,7 @@ class Task(tf.Module):
model.test_step = functools.partial(validation_step, model=model)
return model
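For context, the `compile_model` hook above wires a task-level step function into Keras `compile()`/`fit()` by overwriting the model's `train_step`/`test_step` with `functools.partial`. A minimal self-contained sketch of that pattern follows; it is not the Model Garden implementation, and the toy task, loss, and data are hypothetical:

```python
import functools
import tensorflow as tf

class ToyTask:
  """Hypothetical stand-in for the Task interface sketched above."""

  def build_model(self):
    return tf.keras.Sequential([tf.keras.layers.Dense(1)])

  def train_step(self, data, model):
    # A deliberately simple custom step: mean squared error on (x, y) batches.
    x, y = data
    with tf.GradientTape() as tape:
      loss = tf.reduce_mean(tf.square(model(x, training=True) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    model.optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return {'loss': loss}

task = ToyTask()
model = task.build_model()
model.compile(optimizer='sgd')
# Bind the task's step onto the model, mirroring the partial() call above,
# so that model.fit() drives the task-defined training logic.
model.train_step = functools.partial(task.train_step, model=model)
model.fit(tf.ones((8, 4)), tf.ones((8, 1)), epochs=1, verbose=0)
```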
@abc.abstractmethod
def build_inputs(self,
params: cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None):
......@@ -112,20 +115,19 @@ class Task(tf.Module):
Returns:
A nested structure of per-replica input functions.
"""
pass
def build_losses(self, features, model_outputs, aux_losses=None) -> tf.Tensor:
def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
"""Standard interface to compute losses.
Args:
features: optional feature/labels tensors.
labels: optional label tensors.
model_outputs: a nested structure of output tensors.
aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.
Returns:
The total loss tensor.
"""
del model_outputs, features
del model_outputs, labels
if aux_losses is None:
losses = [tf.constant(0.0, dtype=tf.float32)]
......@@ -139,29 +141,29 @@ class Task(tf.Module):
del training
return []
def process_metrics(self, metrics, labels, outputs):
def process_metrics(self, metrics, labels, model_outputs):
"""Process and update metrics. Called when using custom training loop API.
Args:
metrics: a nested structure of metrics objects.
The return of function self.build_metrics.
labels: a tensor or a nested structure of tensors.
outputs: a tensor or a nested structure of tensors.
model_outputs: a tensor or a nested structure of tensors.
For example, output of the keras model built by self.build_model.
"""
for metric in metrics:
metric.update_state(labels, outputs)
metric.update_state(labels, model_outputs)
def process_compiled_metrics(self, compiled_metrics, labels, outputs):
def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
"""Process and update compiled_metrics. call when using compile/fit API.
Args:
compiled_metrics: the compiled metrics (model.compiled_metrics).
labels: a tensor or a nested structure of tensors.
outputs: a tensor or a nested structure of tensors.
model_outputs: a tensor or a nested structure of tensors.
For example, output of the keras model built by self.build_model.
"""
compiled_metrics.update_state(labels, outputs)
compiled_metrics.update_state(labels, model_outputs)
def train_step(self,
inputs,
......@@ -187,7 +189,7 @@ class Task(tf.Module):
outputs = model(features, training=True)
# Computes per-replica loss.
loss = self.build_losses(
features=labels, model_outputs=outputs, aux_losses=model.losses)
labels=labels, model_outputs=outputs, aux_losses=model.losses)
# Scales the loss, as the default gradient allreduce performs a sum inside
# the optimizer.
scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
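To make the scaling above concrete: with N replicas, each replica backpropagates loss/N, so the allreduce SUM of those gradients equals the gradient of the mean loss. A tiny hypothetical illustration in plain Python:

```python
num_replicas = 4
per_replica_losses = [2.0, 4.0, 6.0, 8.0]
scaled = [loss / num_replicas for loss in per_replica_losses]  # what each replica backprops
summed = sum(scaled)                                           # what the allreduce SUM produces
assert summed == sum(per_replica_losses) / num_replicas        # == mean loss (5.0)
```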
......@@ -231,7 +233,7 @@ class Task(tf.Module):
features, labels = inputs, inputs
outputs = self.inference_step(features, model)
loss = self.build_losses(
features=labels, model_outputs=outputs, aux_losses=model.losses)
labels=labels, model_outputs=outputs, aux_losses=model.losses)
logs = {self.loss: loss}
if metrics:
self.process_metrics(metrics, labels, outputs)
......@@ -250,11 +252,44 @@ _REGISTERED_TASK_CLS = {}
# TODO(b/158268740): Move these outside the base class file.
def register_task_cls(task_config: cfg.TaskConfig) -> Task:
"""Register ExperimentConfig factory method."""
return registry.register(_REGISTERED_TASK_CLS, task_config)
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def register_task_cls(task_config_cls):
"""Decorates a factory of Tasks for lookup by a subclass of TaskConfig.
This decorator supports registration of tasks as follows:
```
@dataclasses.dataclass
class MyTaskConfig(TaskConfig):
# Add fields here.
pass
@register_task_cls(MyTaskConfig)
class MyTask(Task):
# Inherits def __init__(self, task_config).
pass
my_task_config = MyTaskConfig()
my_task = get_task(my_task_config) # Returns MyTask(my_task_config).
```
Besides a class itself, other callables that create a Task from a TaskConfig
can be decorated by the result of this function, as long as there is at most
one registration for each config class.
Args:
task_config_cls: a subclass of TaskConfig (*not* an instance of TaskConfig).
Each task_config_cls can only be used for a single registration.
Returns:
A callable for use as class decorator that registers the decorated class
for creation from an instance of task_config_cls.
"""
return registry.register(_REGISTERED_TASK_CLS, task_config_cls)
def get_task_cls(task_config: cfg.TaskConfig) -> Task:
task_cls = registry.lookup(_REGISTERED_TASK_CLS, task_config)
# The user-visible get_task() is defined after classes have been registered.
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def get_task_cls(task_config_cls):
task_cls = registry.lookup(_REGISTERED_TASK_CLS, task_config_cls)
return task_cls
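`register_task_cls` and `get_task_cls` above delegate to `official.utils.registry`. A minimal sketch of the underlying register/lookup pattern, assuming a plain dict keyed by the config class; the helper names below are simplified stand-ins, not the actual registry API:

```python
_REGISTERED = {}

def register(registered_collection, key_cls):
  """Returns a decorator that maps key_cls to the decorated callable."""
  def decorator(fn_or_cls):
    if key_cls in registered_collection:
      raise KeyError('%s is already registered.' % key_cls)
    registered_collection[key_cls] = fn_or_cls
    return fn_or_cls
  return decorator

def lookup(registered_collection, key):
  """Looks up by the config instance's class (or by the class itself)."""
  key_cls = key if isinstance(key, type) else type(key)
  return registered_collection[key_cls]

# Usage mirroring the docstring example above.
class MyTaskConfig:  # stand-in for a TaskConfig subclass
  pass

@register(_REGISTERED, MyTaskConfig)
class MyTask:
  def __init__(self, config):
    self.config = config

assert lookup(_REGISTERED, MyTaskConfig()) is MyTask
```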
......@@ -162,6 +162,21 @@ class CallbacksConfig(base_config.Config):
@dataclasses.dataclass
class TrainerConfig(base_config.Config):
"""Configuration for trainer.
Attributes:
optimizer_config: optimizer config; it includes optimizer, learning rate,
and warmup schedule configs.
train_tf_while_loop: whether or not to use tf while loop.
train_tf_function: whether or not to use tf_function for training loop.
eval_tf_function: whether or not to use tf_function for eval.
steps_per_loop: number of steps per loop.
summary_interval: number of steps between each summary.
checkpoint_interval: number of steps between checkpoints.
max_to_keep: max checkpoints to keep.
continuous_eval_timeout: maximum number of seconds to wait between
checkpoints. If set to None, continuous eval will wait indefinitely.
"""
optimizer_config: OptimizationConfig = OptimizationConfig()
train_tf_while_loop: bool = True
train_tf_function: bool = True
......@@ -170,6 +185,7 @@ class TrainerConfig(base_config.Config):
summary_interval: int = 1000
checkpoint_interval: int = 1000
max_to_keep: int = 5
continuous_eval_timeout: Optional[int] = None
@dataclasses.dataclass
......
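Given the fields documented above, a trainer section overriding the new `continuous_eval_timeout` could be built roughly like this. This is a sketch: only attributes listed in the docstring are assumed, and the import follows the `config_definitions as cfg` alias used earlier in this diff:

```python
from official.modeling.hyperparams import config_definitions as cfg

trainer = cfg.TrainerConfig(
    steps_per_loop=500,
    summary_interval=500,
    checkpoint_interval=1000,
    max_to_keep=3,
    continuous_eval_timeout=3600,  # stop waiting after an hour with no new checkpoint
)
```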
......@@ -230,9 +230,10 @@ def pretrain_model(bert_config,
initializer=initializer,
output='predictions')
lm_output, sentence_output = pretrainer_model(
outputs = pretrainer_model(
[input_word_ids, input_mask, input_type_ids, masked_lm_positions])
lm_output = outputs['masked_lm']
sentence_output = outputs['classification']
pretrain_loss_layer = BertPretrainLossAndMetricLayer(
vocab_size=bert_config.vocab_size)
output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
......
......@@ -111,6 +111,7 @@ def run_customized_training_loop(
model_dir=None,
train_input_fn=None,
steps_per_epoch=None,
num_eval_per_epoch=1,
steps_per_loop=None,
epochs=1,
eval_input_fn=None,
......@@ -144,6 +145,7 @@ def run_customized_training_loop(
steps_per_epoch: Number of steps to run per epoch. At the end of each
epoch, model checkpoint will be saved and evaluation will be conducted
if evaluation dataset is provided.
num_eval_per_epoch: Number of evaluations per epoch.
steps_per_loop: Number of steps per graph-mode loop. In order to reduce
communication in eager context, training logs are printed every
steps_per_loop.
......@@ -158,16 +160,17 @@ def run_customized_training_loop(
init_checkpoint: Optional checkpoint to load to `sub_model` returned by
`model_fn`.
custom_callbacks: A list of Keras Callbacks objects to run during
training. More specifically, `on_batch_begin()`, `on_batch_end()`,
`on_epoch_begin()`, `on_epoch_end()` methods are invoked during
training. Note that some metrics may be missing from `logs`.
training. More specifically, `on_train_begin()`, `on_train_end()`,
`on_batch_begin()`, `on_batch_end()`, `on_epoch_begin()`, and
`on_epoch_end()` methods are invoked during training.
Note that some metrics may be missing from `logs`.
run_eagerly: Whether to run model training in pure eager execution. This
should be disabled for TPUStrategy.
sub_model_export_name: If not None, will export `sub_model` returned by
`model_fn` into checkpoint files. The name of intermediate checkpoint
file is {sub_model_export_name}_step_{step}.ckpt and the last
checkpoint's name is {sub_model_export_name}.ckpt;
if None, `sub_model` will not be exported as checkpoint.
checkpoint's name is {sub_model_export_name}.ckpt; if None, `sub_model`
will not be exported as checkpoint.
explicit_allreduce: Whether to explicitly perform gradient allreduce,
instead of relying on implicit allreduce in optimizer.apply_gradients().
Default is False. For now, when training with FP16 mixed precision,
......@@ -177,10 +180,10 @@ def run_customized_training_loop(
pre_allreduce_callbacks: A list of callback functions that take gradient
and model variable pairs as input, manipulate them, and return new
gradient and model variable pairs. The callback functions will be
invoked in the list order and before gradients are allreduced.
With mixed precision training, the pre_allreduce_callbacks will be
applied on scaled_gradients. Default is no callbacks.
Only used when explicit_allreduce=True.
invoked in the list order and before gradients are allreduced. With
mixed precision training, the pre_allreduce_callbacks will be applied on
scaled_gradients. Default is no callbacks. Only used when
explicit_allreduce=True.
post_allreduce_callbacks: A list of callback functions that take
gradient and model variable pairs as input, manipulate them, and
return new gradient and model variable pairs. The callback
......@@ -208,6 +211,8 @@ def run_customized_training_loop(
required_arguments = [
strategy, model_fn, loss_fn, model_dir, steps_per_epoch, train_input_fn
]
steps_between_evals = int(steps_per_epoch / num_eval_per_epoch)
if [arg for arg in required_arguments if arg is None]:
raise ValueError('`strategy`, `model_fn`, `loss_fn`, `model_dir`, '
'`steps_per_epoch` and `train_input_fn` are required '
......@@ -216,17 +221,17 @@ def run_customized_training_loop(
if tf.config.list_logical_devices('TPU'):
# One can't fully utilize a TPU with steps_per_loop=1, so in this case
# we default to a more useful value.
steps_per_loop = min(1000, steps_per_epoch)
steps_per_loop = min(1000, steps_between_evals)
else:
steps_per_loop = 1
logging.info('steps_per_loop not specified. Using steps_per_loop=%d',
steps_per_loop)
if steps_per_loop > steps_per_epoch:
if steps_per_loop > steps_between_evals:
logging.warning(
'steps_per_loop: %d is specified to be greater than '
' steps_per_epoch: %d, we will use steps_per_epoch as'
' steps_per_loop.', steps_per_loop, steps_per_epoch)
steps_per_loop = steps_per_epoch
' steps_between_evals: %d, we will use steps_between_evals as'
' steps_per_loop.', steps_per_loop, steps_between_evals)
steps_per_loop = steps_between_evals
assert tf.executing_eagerly()
if run_eagerly:
......@@ -242,12 +247,9 @@ def run_customized_training_loop(
raise ValueError(
'if `metric_fn` is specified, metric_fn must be a callable.')
callback_list = tf.keras.callbacks.CallbackList(custom_callbacks)
total_training_steps = steps_per_epoch * epochs
train_iterator = _get_input_iterator(train_input_fn, strategy)
eval_loss_metric = tf.keras.metrics.Mean(
'training_loss', dtype=tf.float32)
eval_loss_metric = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
with distribution_utils.get_strategy_scope(strategy):
# To correctly place the model weights on accelerators,
......@@ -260,6 +262,9 @@ def run_customized_training_loop(
raise ValueError('sub_model_export_name is specified as %s, but '
'sub_model is None.' % sub_model_export_name)
callback_list = tf.keras.callbacks.CallbackList(
callbacks=custom_callbacks, model=model)
optimizer = model.optimizer
if init_checkpoint:
......@@ -270,8 +275,7 @@ def run_customized_training_loop(
checkpoint.restore(init_checkpoint).assert_existing_objects_matched()
logging.info('Loading from checkpoint file completed')
train_loss_metric = tf.keras.metrics.Mean(
'training_loss', dtype=tf.float32)
train_loss_metric = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
eval_metrics = [metric_fn()] if metric_fn else []
# If evaluation is required, make a copy of metric as it will be used by
# both train and evaluation.
......@@ -440,8 +444,7 @@ def run_customized_training_loop(
latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
if latest_checkpoint_file:
logging.info(
'Checkpoint file %s found and restoring from '
logging.info('Checkpoint file %s found and restoring from '
'checkpoint', latest_checkpoint_file)
checkpoint.restore(latest_checkpoint_file)
logging.info('Loading from checkpoint file completed')
......@@ -449,9 +452,12 @@ def run_customized_training_loop(
current_step = optimizer.iterations.numpy()
checkpoint_name = 'ctl_step_{step}.ckpt'
while current_step < total_training_steps:
logs = {}
callback_list.on_train_begin()
while current_step < total_training_steps and not model.stop_training:
if current_step % steps_per_epoch == 0:
callback_list.on_epoch_begin(int(current_step / steps_per_epoch) + 1)
callback_list.on_epoch_begin(
int(current_step / steps_per_epoch) + 1)
# Training loss/metric are averaged over steps inside the micro
# training loop. We reset their values before each round.
......@@ -461,7 +467,7 @@ def run_customized_training_loop(
callback_list.on_batch_begin(current_step)
# Runs several steps in the host while loop.
steps = steps_to_run(current_step, steps_per_epoch, steps_per_loop)
steps = steps_to_run(current_step, steps_between_evals, steps_per_loop)
if tf.config.list_physical_devices('GPU'):
# TODO(zongweiz): merge with train_steps once tf.while_loop
......@@ -470,11 +476,9 @@ def run_customized_training_loop(
train_single_step(train_iterator)
else:
# Converts steps to a Tensor to avoid tf.function retracing.
train_steps(train_iterator,
tf.convert_to_tensor(steps, dtype=tf.int32))
train_steps(train_iterator, tf.convert_to_tensor(steps, dtype=tf.int32))
train_loss = _float_metric_value(train_loss_metric)
current_step += steps
callback_list.on_batch_end(current_step - 1, {'loss': train_loss})
# Updates training logging.
training_status = 'Train Step: %d/%d / loss = %s' % (
......@@ -492,8 +496,7 @@ def run_customized_training_loop(
'learning_rate',
optimizer.learning_rate(current_step),
step=current_step)
tf.summary.scalar(
train_loss_metric.name, train_loss, step=current_step)
tf.summary.scalar(train_loss_metric.name, train_loss, step=current_step)
for metric in train_metrics + model.metrics:
metric_value = _float_metric_value(metric)
training_status += ' %s = %f' % (metric.name, metric_value)
......@@ -501,7 +504,11 @@ def run_customized_training_loop(
summary_writer.flush()
logging.info(training_status)
if current_step % steps_per_epoch == 0:
# If no evaluation is needed, we only call on_batch_end with train_loss;
# this ensures we get granular global_step/sec on TensorBoard.
if current_step % steps_between_evals:
callback_list.on_batch_end(current_step - 1, {'loss': train_loss})
else:
# Save a submodel with the step in the file name after each epoch.
if sub_model_export_name:
_save_checkpoint(
......@@ -514,7 +521,6 @@ def run_customized_training_loop(
if current_step < total_training_steps:
_save_checkpoint(strategy, checkpoint, model_dir,
checkpoint_name.format(step=current_step))
logs = None
if eval_input_fn:
logging.info('Running evaluation after step: %s.', current_step)
logs = _run_evaluation(current_step,
......@@ -523,7 +529,14 @@ def run_customized_training_loop(
eval_loss_metric.reset_states()
for metric in eval_metrics + model.metrics:
metric.reset_states()
# We add train_loss here rather than call on_batch_end twice to make
# sure that no duplicated values are generated.
logs['loss'] = train_loss
callback_list.on_batch_end(current_step - 1, logs)
# Calls on_epoch_end after each real epoch ends to prevent miscalculation
# of training steps.
if current_step % steps_per_epoch == 0:
callback_list.on_epoch_end(int(current_step / steps_per_epoch), logs)
if sub_model_export_name:
......@@ -532,14 +545,11 @@ def run_customized_training_loop(
_save_checkpoint(strategy, checkpoint, model_dir,
checkpoint_name.format(step=current_step))
logs = None
if eval_input_fn:
logging.info('Running final evaluation after training is complete.')
logs = _run_evaluation(current_step,
_get_input_iterator(eval_input_fn, strategy))
callback_list.on_epoch_end(int(current_step / steps_per_epoch), logs)
training_summary = {
'total_training_steps': total_training_steps,
'train_loss': _float_metric_value(train_loss_metric),
......@@ -557,4 +567,6 @@ def run_customized_training_loop(
if not _should_export_summary(strategy):
tf.io.gfile.rmtree(summary_dir)
callback_list.on_train_end()
return model
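To make the new evaluation cadence concrete, a small worked example using the same numbers as the updated test below; the variable names mirror the loop's internals, and this is plain Python with no TF required:

```python
steps_per_epoch = 20
num_eval_per_epoch = 4
epochs = 2
steps_per_loop = 10

steps_between_evals = int(steps_per_epoch / num_eval_per_epoch)  # 5
steps_per_loop = min(steps_per_loop, steps_between_evals)        # capped to 5, with a warning
total_training_steps = steps_per_epoch * epochs                  # 40

# on_batch_end fires after every host loop, i.e. at steps 4, 9, ..., 39;
# evaluation runs (and eval metrics appear in logs) on every 5th step, and
# on_epoch_end still fires after each full epoch plus once at the end of training.
batch_end_steps = list(
    range(steps_per_loop - 1, total_training_steps, steps_per_loop))
assert batch_end_steps == [4, 9, 14, 19, 24, 29, 34, 39]
```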
......@@ -258,6 +258,7 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
loss_fn=tf.keras.losses.categorical_crossentropy,
model_dir=model_dir,
steps_per_epoch=20,
num_eval_per_epoch=4,
steps_per_loop=10,
epochs=2,
train_input_fn=input_fn,
......@@ -269,14 +270,15 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
run_eagerly=False)
self.assertEqual(callback.epoch_begin, [(1, {}), (2, {})])
epoch_ends, epoch_end_infos = zip(*callback.epoch_end)
self.assertEqual(list(epoch_ends), [1, 2])
self.assertEqual(list(epoch_ends), [1, 2, 2])
for info in epoch_end_infos:
self.assertIn('accuracy', info)
self.assertEqual(callback.batch_begin,
[(0, {}), (10, {}), (20, {}), (30, {})])
self.assertEqual(callback.batch_begin, [(0, {}), (5, {}), (10, {}),
(15, {}), (20, {}), (25, {}),
(30, {}), (35, {})])
batch_ends, batch_end_infos = zip(*callback.batch_end)
self.assertEqual(list(batch_ends), [9, 19, 29, 39])
self.assertEqual(list(batch_ends), [4, 9, 14, 19, 24, 29, 34, 39])
for info in batch_end_infos:
self.assertIn('loss', info)
......
......@@ -45,6 +45,9 @@ assemble new layers, networks, or models.
should be masked), the output will have masked positions set to
approximately zero.
* [`MaskedLM`](masked_lm.py) implements a masked language model. It assumes the
embedding table variable is passed to it.
* [`ClassificationHead`](cls_head.py) implements a pooling head over a sequence
of embeddings, commonly used by classification tasks.
......
......@@ -18,6 +18,7 @@ from official.nlp.modeling.layers.attention import *
from official.nlp.modeling.layers.cls_head import *
from official.nlp.modeling.layers.dense_einsum import DenseEinsum
from official.nlp.modeling.layers.gated_feedforward import GatedFeedforward
from official.nlp.modeling.layers.masked_lm import MaskedLM
from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax
from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
from official.nlp.modeling.layers.position_embedding import PositionEmbedding
......
......@@ -25,91 +25,74 @@ from official.modeling import tf_utils
@tf.keras.utils.register_keras_serializable(package='Text')
class MaskedLM(tf.keras.Model):
class MaskedLM(tf.keras.layers.Layer):
"""Masked language model network head for BERT modeling.
This network implements a masked language model based on the provided network.
It assumes that the network being passed has a "get_embedding_table()" method.
Arguments:
input_width: The innermost dimension of the input tensor to this network.
num_predictions: The number of predictions to make per sequence.
source_network: The network with the embedding layer to use for the
embedding layer.
embedding_table: The embedding table of a source network. If None, the
`source_network.get_embedding_table()` method is used.
activation: The activation, if any, for the dense layer in this network.
initializer: The initializer for the dense layer in this network. Defaults to
a Glorot uniform initializer.
embedding_table: The embedding table of the targets.
activation: The activation, if any, for the dense layer.
initializer: The initializer for the dense layer. Defaults to a Glorot
uniform initializer.
output: The output style for this network. Can be either 'logits' or
'predictions'.
"""
def __init__(self,
input_width,
num_predictions,
source_network,
embedding_table=None,
embedding_table,
activation=None,
initializer='glorot_uniform',
output='logits',
name='cls/predictions',
**kwargs):
super(MaskedLM, self).__init__(name=name, **kwargs)
self.embedding_table = embedding_table
self.activation = activation
self.initializer = tf.keras.initializers.get(initializer)
if embedding_table is None:
embedding_table = source_network.get_embedding_table()
vocab_size, hidden_size = embedding_table.shape
sequence_data = tf.keras.layers.Input(
shape=(None, input_width), name='sequence_data', dtype=tf.float32)
masked_lm_positions = tf.keras.layers.Input(
shape=(num_predictions,), name='masked_lm_positions', dtype=tf.int32)
masked_lm_input = tf.keras.layers.Lambda(
lambda x: self._gather_indexes(x[0], x[1]))(
[sequence_data, masked_lm_positions])
lm_data = (
tf.keras.layers.Dense(
hidden_size,
activation=activation,
kernel_initializer=initializer,
name='cls/predictions/transform/dense')(masked_lm_input))
lm_data = tf.keras.layers.LayerNormalization(
axis=-1, epsilon=1e-12, name='cls/predictions/transform/LayerNorm')(
lm_data)
lm_data = tf.keras.layers.Lambda(
lambda x: tf.matmul(x, embedding_table, transpose_b=True))(
lm_data)
logits = Bias(
initializer=tf.keras.initializers.Zeros(),
name='cls/predictions/output_bias')(
lm_data)
# We can't use the standard Keras reshape layer here, since it expects
# the input and output batch size to be the same.
reshape_layer = tf.keras.layers.Lambda(
lambda x: tf.reshape(x, [-1, num_predictions, vocab_size]))
self.logits = reshape_layer(logits)
predictions = tf.keras.layers.Activation(tf.nn.log_softmax)(self.logits)
if output == 'logits':
output_tensors = self.logits
elif output == 'predictions':
output_tensors = predictions
else:
if output not in ('predictions', 'logits'):
raise ValueError(
('Unknown `output` value "%s". `output` can be either "logits" or '
'"predictions"') % output)
self._output_type = output
def build(self, input_shape):
self._vocab_size, hidden_size = self.embedding_table.shape
self.dense = tf.keras.layers.Dense(
hidden_size,
activation=self.activation,
kernel_initializer=self.initializer,
name='transform/dense')
self.layer_norm = tf.keras.layers.LayerNormalization(
axis=-1, epsilon=1e-12, name='transform/LayerNorm')
self.bias = self.add_weight(
'output_bias/bias',
shape=(self._vocab_size,),
initializer='zeros',
trainable=True)
super(MaskedLM, self).build(input_shape)
super(MaskedLM, self).__init__(
inputs=[sequence_data, masked_lm_positions],
outputs=output_tensors,
**kwargs)
def call(self, sequence_data, masked_positions):
masked_lm_input = self._gather_indexes(sequence_data, masked_positions)
lm_data = self.dense(masked_lm_input)
lm_data = self.layer_norm(lm_data)
lm_data = tf.matmul(lm_data, self.embedding_table, transpose_b=True)
logits = tf.nn.bias_add(lm_data, self.bias)
masked_positions_shape = tf_utils.get_shape_list(
masked_positions, name='masked_positions_tensor')
logits = tf.reshape(logits,
[-1, masked_positions_shape[1], self._vocab_size])
if self._output_type == 'logits':
return logits
return tf.nn.log_softmax(logits)
def get_config(self):
raise NotImplementedError('MaskedLM cannot be directly serialized at this '
'time. Please use it only in Layers or '
'functionally subclassed Models/Networks.')
raise NotImplementedError('MaskedLM cannot be directly serialized because '
'it has variable sharing logic.')
def _gather_indexes(self, sequence_tensor, positions):
"""Gathers the vectors at the specific positions.
......@@ -139,51 +122,3 @@ class MaskedLM(tf.keras.Model):
output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
return output_tensor
@tf.keras.utils.register_keras_serializable(package='Text')
# Temporary until we can create a Dense layer that ties the embedding.
class Bias(tf.keras.layers.Layer):
"""Adds a bias term to an input."""
def __init__(self,
initializer='zeros',
regularizer=None,
constraint=None,
activation=None,
**kwargs):
super(Bias, self).__init__(**kwargs)
self._initializer = tf.keras.initializers.get(initializer)
self._regularizer = tf.keras.regularizers.get(regularizer)
self._constraint = tf.keras.constraints.get(constraint)
self._activation = tf.keras.activations.get(activation)
def build(self, input_shape):
input_shape = tf.TensorShape(input_shape)
self._bias = self.add_weight(
'bias',
shape=input_shape[1:],
initializer=self._initializer,
regularizer=self._regularizer,
constraint=self._constraint,
dtype=self._dtype,
trainable=True)
super(Bias, self).build(input_shape)
def get_config(self):
config = {
'activation': tf.keras.activations.serialize(self._activation),
'initializer': tf.keras.initializers.serialize(self._initializer),
'regularizer': tf.keras.regularizers.serialize(self._regularizer),
'constraint': tf.keras.constraints.serialize(self._constraint)
}
base_config = super(Bias, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
outputs = tf.nn.bias_add(inputs, self._bias)
if self._activation is not None:
return self._activation(outputs) # pylint: disable=not-callable
else:
return outputs
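A hedged usage sketch of the layer above: the embedding table would normally come from an encoder's `get_embedding_table()`, and the call follows the `call(sequence_data, masked_positions)` signature; the shapes and random data below are illustrative only.

```python
import numpy as np
import tensorflow as tf
from official.nlp.modeling.layers import masked_lm

vocab_size, hidden_size = 100, 64
sequence_length, num_predictions = 32, 21

# Stand-in for encoder.get_embedding_table().
embedding_table = tf.Variable(
    tf.random.normal([vocab_size, hidden_size]), name='word_embeddings')

layer = masked_lm.MaskedLM(embedding_table=embedding_table, output='logits')

sequence_data = tf.random.normal([2, sequence_length, hidden_size])
masked_positions = tf.constant(
    np.random.randint(sequence_length, size=(2, num_predictions)),
    dtype=tf.int32)

logits = layer(sequence_data, masked_positions=masked_positions)
print(logits.shape)  # (2, 21, 100)
```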
......@@ -23,7 +23,7 @@ import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling.networks import masked_lm
from official.nlp.modeling.layers import masked_lm
from official.nlp.modeling.networks import transformer_encoder
......@@ -32,11 +32,10 @@ from official.nlp.modeling.networks import transformer_encoder
@keras_parameterized.run_all_keras_modes
class MaskedLMTest(keras_parameterized.TestCase):
def create_network(self,
def create_layer(self,
vocab_size,
sequence_length,
hidden_size,
num_predictions,
output='predictions',
xformer_stack=None):
# First, create a transformer stack that we can use to get the LM's
......@@ -49,82 +48,32 @@ class MaskedLMTest(keras_parameterized.TestCase):
hidden_size=hidden_size,
num_attention_heads=4,
)
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
lm_outputs, _ = xformer_stack([word_ids, mask, type_ids])
# Create a maskedLM from the transformer stack.
test_network = masked_lm.MaskedLM(
num_predictions=num_predictions,
input_width=lm_outputs.shape[-1],
source_network=xformer_stack,
test_layer = masked_lm.MaskedLM(
embedding_table=xformer_stack.get_embedding_table(),
output=output)
return test_network
return test_layer
def test_network_creation(self):
def test_layer_creation(self):
vocab_size = 100
sequence_length = 32
hidden_size = 64
num_predictions = 21
test_network = self.create_network(
test_layer = self.create_layer(
vocab_size=vocab_size,
sequence_length=sequence_length,
hidden_size=hidden_size,
num_predictions=num_predictions)
hidden_size=hidden_size)
# Make sure that the output tensor of the masked LM is the right shape.
lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
masked_lm_positions = tf.keras.Input(
shape=(num_predictions,), dtype=tf.int32)
output = test_network([lm_input_tensor, masked_lm_positions])
masked_positions = tf.keras.Input(shape=(num_predictions,), dtype=tf.int32)
output = test_layer(lm_input_tensor, masked_positions=masked_positions)
expected_output_shape = [None, num_predictions, vocab_size]
self.assertEqual(expected_output_shape, output.shape.as_list())
def test_network_invocation_with_internal_logits(self):
vocab_size = 100
sequence_length = 32
hidden_size = 64
num_predictions = 21
test_network = self.create_network(
vocab_size=vocab_size,
sequence_length=sequence_length,
hidden_size=hidden_size,
num_predictions=num_predictions)
# Create a model from the masked LM layer.
lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
masked_lm_positions = tf.keras.Input(
shape=(num_predictions,), dtype=tf.int32)
output = test_network([lm_input_tensor, masked_lm_positions])
model = tf.keras.Model([lm_input_tensor, masked_lm_positions], output)
logits_model = tf.keras.Model(test_network.inputs, test_network.logits)
# Invoke the masked LM on some fake data to make sure there are no runtime
# errors in the code.
batch_size = 3
lm_input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, hidden_size))
masked_position_data = np.random.randint(
2, size=(batch_size, num_predictions))
outputs = model.predict([lm_input_data, masked_position_data])
logits = logits_model.predict([lm_input_data, masked_position_data])
# Ensure that the tensor shapes are correct.
expected_output_shape = (batch_size, num_predictions, vocab_size)
self.assertEqual(expected_output_shape, outputs.shape)
self.assertEqual(expected_output_shape, logits.shape)
# Ensure that the logits, when softmaxed, create the outputs.
input_tensor = tf.keras.Input(expected_output_shape[1:])
output_tensor = tf.keras.layers.Activation(tf.nn.log_softmax)(input_tensor)
softmax_model = tf.keras.Model(input_tensor, output_tensor)
calculated_softmax = softmax_model.predict(logits)
self.assertAllClose(outputs, calculated_softmax)
def test_network_invocation_with_external_logits(self):
def test_layer_invocation_with_external_logits(self):
vocab_size = 100
sequence_length = 32
hidden_size = 64
......@@ -136,31 +85,28 @@ class MaskedLMTest(keras_parameterized.TestCase):
hidden_size=hidden_size,
num_attention_heads=4,
)
test_network = self.create_network(
test_layer = self.create_layer(
vocab_size=vocab_size,
sequence_length=sequence_length,
hidden_size=hidden_size,
num_predictions=num_predictions,
xformer_stack=xformer_stack,
output='predictions')
logit_network = self.create_network(
logit_layer = self.create_layer(
vocab_size=vocab_size,
sequence_length=sequence_length,
hidden_size=hidden_size,
num_predictions=num_predictions,
xformer_stack=xformer_stack,
output='logits')
logit_network.set_weights(test_network.get_weights())
# Create a model from the masked LM layer.
lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
masked_lm_positions = tf.keras.Input(
shape=(num_predictions,), dtype=tf.int32)
output = test_network([lm_input_tensor, masked_lm_positions])
logit_output = logit_network([lm_input_tensor, masked_lm_positions])
model = tf.keras.Model([lm_input_tensor, masked_lm_positions], output)
logits_model = tf.keras.Model(([lm_input_tensor, masked_lm_positions]),
masked_positions = tf.keras.Input(shape=(num_predictions,), dtype=tf.int32)
output = test_layer(lm_input_tensor, masked_positions)
logit_output = logit_layer(lm_input_tensor, masked_positions)
logit_output = tf.keras.layers.Activation(tf.nn.log_softmax)(logit_output)
logit_layer.set_weights(test_layer.get_weights())
model = tf.keras.Model([lm_input_tensor, masked_positions], output)
logits_model = tf.keras.Model(([lm_input_tensor, masked_positions]),
logit_output)
# Invoke the masked LM on some fake data to make sure there are no runtime
......@@ -169,40 +115,33 @@ class MaskedLMTest(keras_parameterized.TestCase):
lm_input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, hidden_size))
masked_position_data = np.random.randint(
2, size=(batch_size, num_predictions))
outputs = model.predict([lm_input_data, masked_position_data])
logits = logits_model.predict([lm_input_data, masked_position_data])
sequence_length, size=(batch_size, num_predictions))
# ref_outputs = model.predict([lm_input_data, masked_position_data])
# outputs = logits_model.predict([lm_input_data, masked_position_data])
ref_outputs = model([lm_input_data, masked_position_data])
outputs = logits_model([lm_input_data, masked_position_data])
# Ensure that the tensor shapes are correct.
expected_output_shape = (batch_size, num_predictions, vocab_size)
self.assertEqual(expected_output_shape, ref_outputs.shape)
self.assertEqual(expected_output_shape, outputs.shape)
self.assertEqual(expected_output_shape, logits.shape)
self.assertAllClose(ref_outputs, outputs)
# Ensure that the logits, when softmaxed, create the outputs.
input_tensor = tf.keras.Input(expected_output_shape[1:])
output_tensor = tf.keras.layers.Activation(tf.nn.log_softmax)(input_tensor)
softmax_model = tf.keras.Model(input_tensor, output_tensor)
calculated_softmax = softmax_model.predict(logits)
self.assertAllClose(outputs, calculated_softmax)
def test_network_invocation(self):
def test_layer_invocation(self):
vocab_size = 100
sequence_length = 32
hidden_size = 64
num_predictions = 21
test_network = self.create_network(
test_layer = self.create_layer(
vocab_size=vocab_size,
sequence_length=sequence_length,
hidden_size=hidden_size,
num_predictions=num_predictions)
hidden_size=hidden_size)
# Create a model from the masked LM layer.
lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
masked_lm_positions = tf.keras.Input(
shape=(num_predictions,), dtype=tf.int32)
output = test_network([lm_input_tensor, masked_lm_positions])
model = tf.keras.Model([lm_input_tensor, masked_lm_positions], output)
masked_positions = tf.keras.Input(shape=(num_predictions,), dtype=tf.int32)
output = test_layer(lm_input_tensor, masked_positions)
model = tf.keras.Model([lm_input_tensor, masked_positions], output)
# Invoke the masked LM on some fake data to make sure there are no runtime
# errors in the code.
......@@ -215,12 +154,8 @@ class MaskedLMTest(keras_parameterized.TestCase):
def test_unknown_output_type_fails(self):
with self.assertRaisesRegex(ValueError, 'Unknown `output` value "bad".*'):
_ = self.create_network(
vocab_size=8,
sequence_length=8,
hidden_size=8,
num_predictions=8,
output='bad')
_ = self.create_layer(
vocab_size=8, sequence_length=8, hidden_size=8, output='bad')
if __name__ == '__main__':
......
......@@ -23,6 +23,7 @@ import numpy as np
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling import layers
from official.nlp.modeling import networks
from official.nlp.modeling.losses import weighted_sparse_categorical_crossentropy
......@@ -48,20 +49,18 @@ class ClassificationLossTest(keras_parameterized.TestCase):
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
lm_outputs, _ = xformer_stack([word_ids, mask, type_ids])
_ = xformer_stack([word_ids, mask, type_ids])
# Create a maskedLM from the transformer stack.
test_network = networks.MaskedLM(
num_predictions=num_predictions,
input_width=lm_outputs.shape[-1],
source_network=xformer_stack,
test_layer = layers.MaskedLM(
embedding_table=xformer_stack.get_embedding_table(),
output=output)
# Create a model from the masked LM layer.
lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
masked_lm_positions = tf.keras.Input(
shape=(num_predictions,), dtype=tf.int32)
output = test_network([lm_input_tensor, masked_lm_positions])
output = test_layer(lm_input_tensor, masked_positions=masked_lm_positions)
return tf.keras.Model([lm_input_tensor, masked_lm_positions], output)
def create_classification_model(self, input_width, num_classes):
......
......@@ -25,6 +25,7 @@ from typing import List, Optional
import gin
import tensorflow as tf
from official.nlp.modeling import layers
from official.nlp.modeling import networks
......@@ -47,8 +48,8 @@ class BertPretrainer(tf.keras.Model):
num_token_predictions: Number of tokens to predict from the masked LM.
embedding_table: Embedding table of a network. If None, the
"network.get_embedding_table()" is used.
activation: The activation (if any) to use in the masked LM network.
If None, no activation will be used.
activation: The activation (if any) to use in the masked LM network. If
None, no activation will be used.
initializer: The initializer (if any) to use in the masked LM and
classification networks. Defaults to a Glorot uniform initializer.
output: The output style for this network. Can be either 'logits' or
......@@ -106,16 +107,16 @@ class BertPretrainer(tf.keras.Model):
dtype=tf.int32)
inputs.append(masked_lm_positions)
self.masked_lm = networks.MaskedLM(
num_predictions=num_token_predictions,
input_width=sequence_output.shape[-1],
source_network=network,
if embedding_table is None:
embedding_table = self.encoder.get_embedding_table()
self.masked_lm = layers.MaskedLM(
embedding_table=embedding_table,
activation=activation,
initializer=initializer,
output=output,
name='masked_lm')
lm_outputs = self.masked_lm([sequence_output, masked_lm_positions])
name='cls/predictions')
lm_outputs = self.masked_lm(
sequence_output, masked_positions=masked_lm_positions)
self.classification = networks.Classification(
input_width=cls_output.shape[-1],
......@@ -126,7 +127,9 @@ class BertPretrainer(tf.keras.Model):
sentence_outputs = self.classification(cls_output)
super(BertPretrainer, self).__init__(
inputs=inputs, outputs=[lm_outputs, sentence_outputs], **kwargs)
inputs=inputs,
outputs=dict(masked_lm=lm_outputs, classification=sentence_outputs),
**kwargs)
def get_config(self):
return self._config
......@@ -151,8 +154,8 @@ class BertPretrainerV2(tf.keras.Model):
num_masked_tokens: Number of tokens to predict from the masked LM.
encoder_network: A transformer network. This network should output a
sequence output and a classification output.
mlm_activation: The activation (if any) to use in the masked LM network.
If None, no activation will be used.
mlm_activation: The activation (if any) to use in the masked LM network. If
None, no activation will be used.
mlm_initializer: The initializer (if any) to use in the masked LM. Default
to a Glorot uniform initializer.
classification_heads: A list of optional head layers to transform on encoder
......@@ -193,17 +196,18 @@ class BertPretrainerV2(tf.keras.Model):
outputs = dict()
if num_masked_tokens > 0:
self.masked_lm = networks.MaskedLM(
num_predictions=num_masked_tokens,
input_width=sequence_output.shape[-1],
source_network=self.encoder_network,
self.masked_lm = layers.MaskedLM(
embedding_table=self.encoder_network.get_embedding_table(),
activation=mlm_activation,
initializer=mlm_initializer,
name='masked_lm')
masked_lm_positions = copy.copy(self.masked_lm.inputs[-1])
name='cls/predictions')
masked_lm_positions = tf.keras.layers.Input(
shape=(num_masked_tokens,),
name='masked_lm_positions',
dtype=tf.int32)
inputs.append(masked_lm_positions)
outputs['lm_output'] = self.masked_lm(
[sequence_output, masked_lm_positions])
sequence_output, masked_positions=masked_lm_positions)
for cls_head in self.classification_heads:
outputs[cls_head.name] = cls_head(sequence_output)
......
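Putting the pieces above together, a hedged sketch of how the V2 pretrainer's dictionary output is consumed. The constructor arguments follow the attributes documented above; the encoder, classification head, and input placeholders are illustrative assumptions rather than code from this change:

```python
import tensorflow as tf
from official.nlp.modeling import layers, networks

seq_length, num_masked = 128, 20
encoder = networks.TransformerEncoder(
    vocab_size=30522, num_layers=2, sequence_length=seq_length)
# BertPretrainerV2 is the class defined above.
pretrainer = BertPretrainerV2(
    num_masked_tokens=num_masked,
    encoder_network=encoder,
    classification_heads=[
        layers.ClassificationHead(
            inner_dim=64, num_classes=2, name='next_sentence')
    ])

word_ids = tf.keras.Input(shape=(seq_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(seq_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(seq_length,), dtype=tf.int32)
masked_lm_positions = tf.keras.Input(shape=(num_masked,), dtype=tf.int32)

outputs = pretrainer([word_ids, mask, type_ids, masked_lm_positions])
lm_logits = outputs['lm_output']       # [batch, num_masked, vocab_size]
nsp_logits = outputs['next_sentence']  # one entry per classification head name
```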
......@@ -50,16 +50,19 @@ class BertPretrainerTest(keras_parameterized.TestCase):
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
lm_mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
masked_lm_positions = tf.keras.Input(
shape=(num_token_predictions,), dtype=tf.int32)
# Invoke the trainer model on the inputs. This causes the layer to be built.
lm_outs, cls_outs = bert_trainer_model([word_ids, mask, type_ids, lm_mask])
outputs = bert_trainer_model(
[word_ids, mask, type_ids, masked_lm_positions])
# Validate that the outputs are of the expected shape.
expected_lm_shape = [None, num_token_predictions, vocab_size]
expected_classification_shape = [None, num_classes]
self.assertAllEqual(expected_lm_shape, lm_outs.shape.as_list())
self.assertAllEqual(expected_classification_shape, cls_outs.shape.as_list())
self.assertAllEqual(expected_lm_shape, outputs['masked_lm'].shape.as_list())
self.assertAllEqual(expected_classification_shape,
outputs['classification'].shape.as_list())
def test_bert_trainer_tensor_call(self):
"""Validate that the Keras object can be invoked."""
......@@ -81,7 +84,7 @@ class BertPretrainerTest(keras_parameterized.TestCase):
# Invoke the trainer model on the tensors. In Eager mode, this does the
# actual calculation. (We can't validate the outputs, since the network is
# too complex: this simply ensures we're not hitting runtime errors.)
_, _ = bert_trainer_model([word_ids, mask, type_ids, lm_mask])
_ = bert_trainer_model([word_ids, mask, type_ids, lm_mask])
def test_serialize_deserialize(self):
"""Validate that the BERT trainer can be serialized and deserialized."""
......@@ -123,7 +126,7 @@ class BertPretrainerTest(keras_parameterized.TestCase):
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
lm_mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
lm_mask = tf.keras.Input(shape=(num_token_predictions,), dtype=tf.int32)
# Invoke the trainer model on the inputs. This causes the layer to be built.
outputs = bert_trainer_model([word_ids, mask, type_ids, lm_mask])
......
......@@ -16,8 +16,6 @@ Self-supervised Learning of Language Representations]
(https://arxiv.org/abs/1909.11942). Compared with [BERT](https://arxiv.org/abs/1810.04805), ALBERT factorizes embedding parameters
into two smaller matrices and shares parameters across layers.
* [`MaskedLM`](masked_lm.py) implements a masked language model for BERT pretraining. It assumes that the network being passed has a `get_embedding_table()` method.
* [`Classification`](classification.py) contains a single hidden layer, and is
intended for use as a classification or regression (if number of classes is set
to 1) head.
......
......@@ -16,7 +16,6 @@
from official.nlp.modeling.networks.albert_transformer_encoder import AlbertTransformerEncoder
from official.nlp.modeling.networks.classification import Classification
from official.nlp.modeling.networks.encoder_scaffold import EncoderScaffold
from official.nlp.modeling.networks.masked_lm import MaskedLM
from official.nlp.modeling.networks.span_labeling import SpanLabeling
from official.nlp.modeling.networks.token_classification import TokenClassification
from official.nlp.modeling.networks.transformer_encoder import TransformerEncoder
......@@ -43,25 +43,25 @@ class MaskedLMTask(base_task.Task):
return bert.instantiate_from_cfg(self.task_config.network)
def build_losses(self,
features,
labels,
model_outputs,
metrics,
aux_losses=None) -> tf.Tensor:
metrics = dict([(metric.name, metric) for metric in metrics])
lm_output = tf.nn.log_softmax(model_outputs['lm_output'], axis=-1)
mlm_loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
labels=features['masked_lm_ids'],
labels=labels['masked_lm_ids'],
predictions=lm_output,
weights=features['masked_lm_weights'])
weights=labels['masked_lm_weights'])
metrics['lm_example_loss'].update_state(mlm_loss)
if 'next_sentence_labels' in features:
if 'next_sentence_labels' in labels:
policy = tf.keras.mixed_precision.experimental.global_policy()
if policy.name == 'mixed_bfloat16': # b/158514794: bf16 is not stable.
policy = tf.float32
predictions = tf.keras.layers.Activation(
tf.nn.log_softmax, dtype=policy)(model_outputs['next_sentence'])
sentence_labels = features['next_sentence_labels']
sentence_labels = labels['next_sentence_labels']
sentence_loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
labels=sentence_labels,
predictions=predictions)
......@@ -112,15 +112,15 @@ class MaskedLMTask(base_task.Task):
metrics.append(tf.keras.metrics.Mean(name='next_sentence_loss'))
return metrics
def process_metrics(self, metrics, inputs, outputs):
def process_metrics(self, metrics, labels, model_outputs):
metrics = dict([(metric.name, metric) for metric in metrics])
if 'masked_lm_accuracy' in metrics:
metrics['masked_lm_accuracy'].update_state(inputs['masked_lm_ids'],
outputs['lm_output'],
inputs['masked_lm_weights'])
metrics['masked_lm_accuracy'].update_state(labels['masked_lm_ids'],
model_outputs['lm_output'],
labels['masked_lm_weights'])
if 'next_sentence_accuracy' in metrics:
metrics['next_sentence_accuracy'].update_state(
inputs['next_sentence_labels'], outputs['next_sentence'])
labels['next_sentence_labels'], model_outputs['next_sentence'])
def train_step(self, inputs, model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer, metrics):
......@@ -139,7 +139,7 @@ class MaskedLMTask(base_task.Task):
outputs = model(inputs, training=True)
# Computes per-replica loss.
loss = self.build_losses(
features=inputs,
labels=inputs,
model_outputs=outputs,
metrics=metrics,
aux_losses=model.losses)
......@@ -166,7 +166,7 @@ class MaskedLMTask(base_task.Task):
"""
outputs = self.inference_step(inputs, model)
loss = self.build_losses(
features=inputs,
labels=inputs,
model_outputs=outputs,
metrics=metrics,
aux_losses=model.losses)
......
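For readers unfamiliar with the loss used above, a hedged re-derivation of weighted sparse categorical cross-entropy over log-probabilities; this is a simplified stand-in, not the `loss_lib` implementation, and the toy numbers are illustrative:

```python
import tensorflow as tf

def weighted_sparse_xent(labels, log_probs, weights=None):
  """labels: [batch, n] int ids; log_probs: [batch, n, vocab] log-probabilities."""
  per_example = -tf.gather(log_probs, labels, axis=-1, batch_dims=2)
  if weights is None:
    return tf.reduce_mean(per_example)
  weights = tf.cast(weights, per_example.dtype)
  # Weighted average so padded prediction slots (weight 0) do not contribute.
  return tf.reduce_sum(per_example * weights) / (tf.reduce_sum(weights) + 1e-5)

# Toy check: vocab of 3, one example with two masked positions, second padded out.
log_probs = tf.nn.log_softmax(tf.constant([[[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]]]))
labels = tf.constant([[0, 1]])
weights = tf.constant([[1.0, 0.0]])
print(weighted_sparse_xent(labels, log_probs, weights).numpy())
```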
......@@ -29,9 +29,9 @@ from official.nlp.modeling import losses as loss_lib
@dataclasses.dataclass
class SentencePredictionConfig(cfg.TaskConfig):
"""The model config."""
# At most one of `pretrain_checkpoint_dir` and `hub_module_url` can
# At most one of `init_checkpoint` and `hub_module_url` can
# be specified.
pretrain_checkpoint_dir: str = ''
init_checkpoint: str = ''
hub_module_url: str = ''
network: bert.BertPretrainerConfig = bert.BertPretrainerConfig(
num_masked_tokens=0,
......@@ -52,7 +52,7 @@ class SentencePredictionTask(base_task.Task):
def __init__(self, params=cfg.TaskConfig):
super(SentencePredictionTask, self).__init__(params)
if params.hub_module_url and params.pretrain_checkpoint_dir:
if params.hub_module_url and params.init_checkpoint:
raise ValueError('At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.')
if params.hub_module_url:
......@@ -79,12 +79,11 @@ class SentencePredictionTask(base_task.Task):
else:
return bert.instantiate_from_cfg(self.task_config.network)
def build_losses(self, features, model_outputs, aux_losses=None) -> tf.Tensor:
labels = features
def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
labels=labels,
predictions=tf.nn.log_softmax(model_outputs['sentence_prediction'],
axis=-1))
predictions=tf.nn.log_softmax(
model_outputs['sentence_prediction'], axis=-1))
if aux_losses:
loss += tf.add_n(aux_losses)
......@@ -93,6 +92,7 @@ class SentencePredictionTask(base_task.Task):
def build_inputs(self, params, input_context=None):
"""Returns tf.data.Dataset for sentence_prediction task."""
if params.input_path == 'dummy':
def dummy_data(_):
dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
x = dict(
......@@ -113,22 +113,22 @@ class SentencePredictionTask(base_task.Task):
def build_metrics(self, training=None):
del training
metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')
]
metrics = [tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')]
return metrics
def process_metrics(self, metrics, labels, outputs):
def process_metrics(self, metrics, labels, model_outputs):
for metric in metrics:
metric.update_state(labels, outputs['sentence_prediction'])
metric.update_state(labels, model_outputs['sentence_prediction'])
def process_compiled_metrics(self, compiled_metrics, labels, outputs):
compiled_metrics.update_state(labels, outputs['sentence_prediction'])
def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
compiled_metrics.update_state(labels, model_outputs['sentence_prediction'])
def initialize(self, model):
"""Load a pretrained checkpoint (if exists) and then train from iter 0."""
pretrain_ckpt_dir = self.task_config.pretrain_checkpoint_dir
if not pretrain_ckpt_dir:
ckpt_dir_or_file = self.task_config.init_checkpoint
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
if not ckpt_dir_or_file:
return
pretrain2finetune_mapping = {
......@@ -138,10 +138,7 @@ class SentencePredictionTask(base_task.Task):
model.checkpoint_items['sentence_prediction.pooler_dense'],
}
ckpt = tf.train.Checkpoint(**pretrain2finetune_mapping)
latest_pretrain_ckpt = tf.train.latest_checkpoint(pretrain_ckpt_dir)
if latest_pretrain_ckpt is None:
raise FileNotFoundError(
'Cannot find pretrain checkpoint under {}'.format(pretrain_ckpt_dir))
status = ckpt.restore(latest_pretrain_ckpt)
status = ckpt.restore(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info('finished loading pretrained checkpoint.')
logging.info('finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
......@@ -43,8 +43,10 @@ class SentencePredictionTaskTest(tf.test.TestCase):
def test_task(self):
config = sentence_prediction.SentencePredictionConfig(
init_checkpoint=self.get_temp_dir(),
network=bert.BertPretrainerConfig(
encoders.TransformerEncoderConfig(vocab_size=30522, num_layers=1),
encoder=encoders.TransformerEncoderConfig(
vocab_size=30522, num_layers=1),
num_masked_tokens=0,
cls_heads=[
bert.ClsHeadConfig(
......@@ -62,6 +64,21 @@ class SentencePredictionTaskTest(tf.test.TestCase):
task.train_step(next(iterator), model, optimizer, metrics=metrics)
task.validation_step(next(iterator), model, metrics=metrics)
# Saves a checkpoint.
pretrain_cfg = bert.BertPretrainerConfig(
encoder=encoders.TransformerEncoderConfig(
vocab_size=30522, num_layers=1),
num_masked_tokens=20,
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=3, name="next_sentence")
])
pretrain_model = bert.instantiate_from_cfg(pretrain_cfg)
ckpt = tf.train.Checkpoint(
model=pretrain_model, **pretrain_model.checkpoint_items)
ckpt.save(config.init_checkpoint)
task.initialize(model)
def _export_bert_tfhub(self):
bert_config = configs.BertConfig(
vocab_size=30522,
......
......@@ -218,7 +218,7 @@ def get_callbacks():
time_callback = keras_utils.TimeHistory(
FLAGS.batch_size,
FLAGS.log_steps,
FLAGS.model_dir if FLAGS.enable_tensorboard else None)
logdir=FLAGS.model_dir if FLAGS.enable_tensorboard else None)
callbacks.append(time_callback)
if FLAGS.enable_tensorboard:
......