Commit b1025b3b authored by syiming

Merge remote-tracking branch 'upstream/master' into fasterrcnn_fpn_keras_feature_extractor

parents 69ce1c45 e9df75ab
@@ -32,9 +32,9 @@ from official.vision.segmentation import unet_model as unet_model_lib
 UNET3D_MIN_ACCURACY = 0.90
 UNET3D_MAX_ACCURACY = 0.98
-UNET_TRAINING_FILES = 'unet_training_data_files'
-UNET_EVAL_FILES = 'unet_eval_data_files'
-UNET_MODEL_CONFIG_FILE = 'unet_model_config'
+UNET_TRAINING_FILES = 'gs://mlcompass-data/unet3d/train_data/*'
+UNET_EVAL_FILES = 'gs://mlcompass-data/unet3d/eval_data/*'
+UNET_MODEL_CONFIG_FILE = 'gs://mlcompass-data/unet3d/config/unet_config.yaml'
 FLAGS = flags.FLAGS
...
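The new constants point the test at concrete GCS glob patterns instead of flag-style placeholders. A minimal sketch of how such a pattern resolves to input files, assuming the training data are TFRecords (the diff does not show the reader side):

```
import tensorflow as tf

# tf.io.gfile understands gs:// paths; this pattern is the one added above and
# requires access to that bucket.
train_files = tf.io.gfile.glob('gs://mlcompass-data/unet3d/train_data/*')
# Hypothetical consumption as TFRecords; the actual unet3d pipeline may differ.
dataset = tf.data.TFRecordDataset(train_files)
```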
@@ -14,15 +14,18 @@
 # limitations under the License.
 # ==============================================================================
 """Defines the base task abstraction."""
+import abc
 import functools
 from typing import Any, Callable, Optional

+import six
 import tensorflow as tf

 from official.modeling.hyperparams import config_definitions as cfg
 from official.utils import registry


+@six.add_metaclass(abc.ABCMeta)
 class Task(tf.Module):
   """A single-replica view of training procedure.
@@ -54,14 +57,13 @@ class Task(tf.Module):
     """
     pass

+  @abc.abstractmethod
   def build_model(self) -> tf.keras.Model:
     """Creates the model architecture.

     Returns:
       A model instance.
     """
-    # TODO(hongkuny): the base task should call network factory.
-    pass

   def compile_model(self,
                     model: tf.keras.Model,
@@ -98,6 +100,7 @@ class Task(tf.Module):
     model.test_step = functools.partial(validation_step, model=model)
     return model

+  @abc.abstractmethod
   def build_inputs(self,
                    params: cfg.DataConfig,
                    input_context: Optional[tf.distribute.InputContext] = None):
@@ -112,20 +115,19 @@ class Task(tf.Module):
     Returns:
       A nested structure of per-replica input functions.
     """
-    pass

-  def build_losses(self, features, model_outputs, aux_losses=None) -> tf.Tensor:
+  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
     """Standard interface to compute losses.

     Args:
-      features: optional feature/labels tensors.
+      labels: optional label tensors.
       model_outputs: a nested structure of output tensors.
       aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.

     Returns:
       The total loss tensor.
     """
-    del model_outputs, features
+    del model_outputs, labels
     if aux_losses is None:
       losses = [tf.constant(0.0, dtype=tf.float32)]
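With the rename from `features` to `labels`, a subclass that overrides the default loss now receives the label tensors under that name. A minimal sketch of such an override, with an illustrative loss choice that is not from the diff:

```
class MyClassificationTask(Task):
  """Hypothetical subclass; build_model/build_inputs omitted for brevity."""

  def build_losses(self, labels, model_outputs, aux_losses=None):
    # Illustrative per-example loss; any reduction to a scalar works here.
    loss = tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(
            labels, model_outputs, from_logits=True))
    if aux_losses:
      loss += tf.add_n(aux_losses)
    return loss
```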
@@ -139,29 +141,29 @@ class Task(tf.Module):
     del training
     return []

-  def process_metrics(self, metrics, labels, outputs):
+  def process_metrics(self, metrics, labels, model_outputs):
     """Process and update metrics. Called when using custom training loop API.

     Args:
       metrics: a nested structure of metrics objects.
         The return of function self.build_metrics.
       labels: a tensor or a nested structure of tensors.
-      outputs: a tensor or a nested structure of tensors.
+      model_outputs: a tensor or a nested structure of tensors.
         For example, output of the keras model built by self.build_model.
     """
     for metric in metrics:
-      metric.update_state(labels, outputs)
+      metric.update_state(labels, model_outputs)

-  def process_compiled_metrics(self, compiled_metrics, labels, outputs):
+  def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
     """Process and update compiled_metrics. Called when using compile/fit API.

     Args:
       compiled_metrics: the compiled metrics (model.compiled_metrics).
       labels: a tensor or a nested structure of tensors.
-      outputs: a tensor or a nested structure of tensors.
+      model_outputs: a tensor or a nested structure of tensors.
         For example, output of the keras model built by self.build_model.
     """
-    compiled_metrics.update_state(labels, outputs)
+    compiled_metrics.update_state(labels, model_outputs)

   def train_step(self,
                  inputs,
@@ -187,7 +189,7 @@ class Task(tf.Module):
       outputs = model(features, training=True)
       # Computes per-replica loss.
       loss = self.build_losses(
-          features=labels, model_outputs=outputs, aux_losses=model.losses)
+          labels=labels, model_outputs=outputs, aux_losses=model.losses)
       # Scales loss as the default gradients allreduce performs sum inside the
       # optimizer.
       scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
@@ -231,7 +233,7 @@ class Task(tf.Module):
       features, labels = inputs, inputs
     outputs = self.inference_step(features, model)
     loss = self.build_losses(
-        features=labels, model_outputs=outputs, aux_losses=model.losses)
+        labels=labels, model_outputs=outputs, aux_losses=model.losses)
     logs = {self.loss: loss}
     if metrics:
       self.process_metrics(metrics, labels, outputs)
@@ -250,11 +252,44 @@ _REGISTERED_TASK_CLS = {}

 # TODO(b/158268740): Move these outside the base class file.
-def register_task_cls(task_config: cfg.TaskConfig) -> Task:
-  """Register ExperimentConfig factory method."""
-  return registry.register(_REGISTERED_TASK_CLS, task_config)
+# TODO(b/158741360): Add type annotations once pytype checks across modules.
+def register_task_cls(task_config_cls):
+  """Decorates a factory of Tasks for lookup by a subclass of TaskConfig.
+
+  This decorator supports registration of tasks as follows:
+
+  ```
+  @dataclasses.dataclass
+  class MyTaskConfig(TaskConfig):
+    # Add fields here.
+    pass
+
+  @register_task_cls(MyTaskConfig)
+  class MyTask(Task):
+    # Inherits def __init__(self, task_config).
+    pass
+
+  my_task_config = MyTaskConfig()
+  my_task = get_task(my_task_config)  # Returns MyTask(my_task_config).
+  ```
+
+  Besides a class itself, other callables that create a Task from a TaskConfig
+  can be decorated by the result of this function, as long as there is at most
+  one registration for each config class.
+
+  Args:
+    task_config_cls: a subclass of TaskConfig (*not* an instance of
+      TaskConfig). Each task_config_cls can only be used for a single
+      registration.
+
+  Returns:
+    A callable for use as class decorator that registers the decorated class
+    for creation from an instance of task_config_cls.
+  """
+  return registry.register(_REGISTERED_TASK_CLS, task_config_cls)


-def get_task_cls(task_config: cfg.TaskConfig) -> Task:
-  task_cls = registry.lookup(_REGISTERED_TASK_CLS, task_config)
+# The user-visible get_task() is defined after classes have been registered.
+# TODO(b/158741360): Add type annotations once pytype checks across modules.
+def get_task_cls(task_config_cls):
+  task_cls = registry.lookup(_REGISTERED_TASK_CLS, task_config_cls)
   return task_cls
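A minimal sketch of the registration round trip described in the docstring above, reusing the hypothetical MyTaskConfig/MyTask names from that example; `get_task_cls` takes the config class, so the caller dispatches on `type(config)`:

```
import dataclasses

@dataclasses.dataclass
class MyTaskConfig(cfg.TaskConfig):  # hypothetical config from the docstring
  pass


@register_task_cls(MyTaskConfig)
class MyTask(Task):  # hypothetical task from the docstring
  pass


config = MyTaskConfig()
task = get_task_cls(type(config))(config)  # Effectively MyTask(config).
```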
@@ -162,6 +162,21 @@ class CallbacksConfig(base_config.Config):

 @dataclasses.dataclass
 class TrainerConfig(base_config.Config):
+  """Configuration for trainer.
+
+  Attributes:
+    optimizer_config: optimizer config, it includes optimizer, learning rate,
+      and warmup schedule configs.
+    train_tf_while_loop: whether or not to use tf while loop.
+    train_tf_function: whether or not to use tf_function for training loop.
+    eval_tf_function: whether or not to use tf_function for eval.
+    steps_per_loop: number of steps per loop.
+    summary_interval: number of steps between each summary.
+    checkpoint_interval: number of steps between checkpoints.
+    max_to_keep: max checkpoints to keep.
+    continuous_eval_timeout: maximum number of seconds to wait between
+      checkpoints; if set to None, continuous eval will wait indefinitely.
+  """
   optimizer_config: OptimizationConfig = OptimizationConfig()
   train_tf_while_loop: bool = True
   train_tf_function: bool = True
@@ -170,6 +185,7 @@ class TrainerConfig(base_config.Config):
   summary_interval: int = 1000
   checkpoint_interval: int = 1000
   max_to_keep: int = 5
+  continuous_eval_timeout: Optional[int] = None


 @dataclasses.dataclass
...
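The new `continuous_eval_timeout` field defaults to None, meaning continuous evaluation waits indefinitely for the next checkpoint. A sketch of constructing the config with explicit overrides; the values are illustrative, not defaults from the diff:

```
trainer = TrainerConfig(
    steps_per_loop=500,
    summary_interval=500,
    checkpoint_interval=1000,
    max_to_keep=3,
    # Stop continuous eval if no new checkpoint appears within an hour.
    continuous_eval_timeout=3600)
```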
@@ -230,9 +230,10 @@ def pretrain_model(bert_config,
       initializer=initializer,
       output='predictions')

-  lm_output, sentence_output = pretrainer_model(
+  outputs = pretrainer_model(
       [input_word_ids, input_mask, input_type_ids, masked_lm_positions])
+  lm_output = outputs['masked_lm']
+  sentence_output = outputs['classification']

   pretrain_loss_layer = BertPretrainLossAndMetricLayer(
       vocab_size=bert_config.vocab_size)
   output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
...
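The pretrainer now returns a dictionary keyed by head name instead of a positional tuple, so callers unpack by key and adding a new head cannot silently reorder outputs. A toy sketch of the same access pattern, independent of BERT:

```
import tensorflow as tf

# Stand-in model with dict outputs mirroring the pretrainer's structure.
inputs = tf.keras.Input(shape=(4,))
outputs = dict(
    masked_lm=tf.keras.layers.Dense(3)(inputs),
    classification=tf.keras.layers.Dense(2)(inputs))
model = tf.keras.Model(inputs=inputs, outputs=outputs)

result = model(tf.zeros([1, 4]))
lm_output = result['masked_lm']
sentence_output = result['classification']
```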
@@ -111,6 +111,7 @@ def run_customized_training_loop(
     model_dir=None,
     train_input_fn=None,
     steps_per_epoch=None,
+    num_eval_per_epoch=1,
     steps_per_loop=None,
     epochs=1,
     eval_input_fn=None,
@@ -144,6 +145,7 @@ def run_customized_training_loop(
       steps_per_epoch: Number of steps to run per epoch. At the end of each
         epoch, model checkpoint will be saved and evaluation will be conducted
         if evaluation dataset is provided.
+      num_eval_per_epoch: Number of evaluations per epoch.
       steps_per_loop: Number of steps per graph-mode loop. In order to reduce
         communication in eager context, training logs are printed every
         steps_per_loop.
@@ -158,16 +160,17 @@ def run_customized_training_loop(
       init_checkpoint: Optional checkpoint to load to `sub_model` returned by
         `model_fn`.
       custom_callbacks: A list of Keras Callbacks objects to run during
-        training. More specifically, `on_batch_begin()`, `on_batch_end()`,
-        `on_epoch_begin()`, `on_epoch_end()` methods are invoked during
-        training. Note that some metrics may be missing from `logs`.
+        training. More specifically, `on_train_begin()`, `on_train_end()`,
+        `on_batch_begin()`, `on_batch_end()`, `on_epoch_begin()`,
+        `on_epoch_end()` methods are invoked during training.
+        Note that some metrics may be missing from `logs`.
       run_eagerly: Whether to run model training in pure eager execution. This
         should be disabled for TPUStrategy.
       sub_model_export_name: If not None, will export `sub_model` returned by
         `model_fn` into checkpoint files. The name of intermediate checkpoint
         file is {sub_model_export_name}_step_{step}.ckpt and the last
-        checkpoint's name is {sub_model_export_name}.ckpt;
-        if None, `sub_model` will not be exported as checkpoint.
+        checkpoint's name is {sub_model_export_name}.ckpt; if None, `sub_model`
+        will not be exported as checkpoint.
       explicit_allreduce: Whether to explicitly perform gradient allreduce,
         instead of relying on implicit allreduce in optimizer.apply_gradients().
         default is False. For now, if training using FP16 mixed precision,
@@ -177,10 +180,10 @@ def run_customized_training_loop(
       pre_allreduce_callbacks: A list of callback functions that takes gradients
         and model variables pairs as input, manipulate them, and returns a new
         gradients and model variables pairs. The callback functions will be
-        invoked in the list order and before gradients are allreduced.
-        With mixed precision training, the pre_allreduce_callbacks will be
-        applied on scaled_gradients. Default is no callbacks.
-        Only used when explicit_allreduce=True.
+        invoked in the list order and before gradients are allreduced. With
+        mixed precision training, the pre_allreduce_callbacks will be applied on
+        scaled_gradients. Default is no callbacks. Only used when
+        explicit_allreduce=True.
       post_allreduce_callbacks: A list of callback functions that takes
         gradients and model variables pairs as input, manipulate them, and
         returns a new gradients and model variables pairs. The callback
@@ -208,6 +211,8 @@ def run_customized_training_loop(
   required_arguments = [
       strategy, model_fn, loss_fn, model_dir, steps_per_epoch, train_input_fn
   ]

+  steps_between_evals = int(steps_per_epoch / num_eval_per_epoch)
   if [arg for arg in required_arguments if arg is None]:
     raise ValueError('`strategy`, `model_fn`, `loss_fn`, `model_dir`, '
                      '`steps_per_epoch` and `train_input_fn` are required '
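Worked arithmetic for the new evaluation cadence, using the values from the updated unit test further down:

```
steps_per_epoch = 20
num_eval_per_epoch = 4
steps_between_evals = int(steps_per_epoch / num_eval_per_epoch)  # == 5

# With epochs=2 the loop runs 40 steps in chunks of 5, so on_batch_end fires
# at steps 4, 9, 14, 19, 24, 29, 34, 39 -- exactly the list the test asserts.
```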
@@ -216,17 +221,17 @@ def run_customized_training_loop(
     if tf.config.list_logical_devices('TPU'):
       # One can't fully utilize a TPU with steps_per_loop=1, so in this case
       # default users to a more useful value.
-      steps_per_loop = min(1000, steps_per_epoch)
+      steps_per_loop = min(1000, steps_between_evals)
     else:
       steps_per_loop = 1
     logging.info('steps_per_loop not specified. Using steps_per_loop=%d',
                  steps_per_loop)
-  if steps_per_loop > steps_per_epoch:
+  if steps_per_loop > steps_between_evals:
     logging.warning(
         'steps_per_loop: %d is specified to be greater than '
-        ' steps_per_epoch: %d, we will use steps_per_epoch as'
-        ' steps_per_loop.', steps_per_loop, steps_per_epoch)
-    steps_per_loop = steps_per_epoch
+        ' steps_between_evals: %d, we will use steps_between_evals as'
+        ' steps_per_loop.', steps_per_loop, steps_between_evals)
+    steps_per_loop = steps_between_evals
   assert tf.executing_eagerly()

   if run_eagerly:
@@ -242,12 +247,9 @@ def run_customized_training_loop(
     raise ValueError(
         'if `metric_fn` is specified, metric_fn must be a callable.')

-  callback_list = tf.keras.callbacks.CallbackList(custom_callbacks)
   total_training_steps = steps_per_epoch * epochs
   train_iterator = _get_input_iterator(train_input_fn, strategy)
-  eval_loss_metric = tf.keras.metrics.Mean(
-      'training_loss', dtype=tf.float32)
+  eval_loss_metric = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)

   with distribution_utils.get_strategy_scope(strategy):
     # To correctly place the model weights on accelerators,
@@ -260,6 +262,9 @@ def run_customized_training_loop(
       raise ValueError('sub_model_export_name is specified as %s, but '
                        'sub_model is None.' % sub_model_export_name)

+    callback_list = tf.keras.callbacks.CallbackList(
+        callbacks=custom_callbacks, model=model)
+
     optimizer = model.optimizer

     if init_checkpoint:
@@ -270,8 +275,7 @@ def run_customized_training_loop(
       checkpoint.restore(init_checkpoint).assert_existing_objects_matched()
       logging.info('Loading from checkpoint file completed')

-    train_loss_metric = tf.keras.metrics.Mean(
-        'training_loss', dtype=tf.float32)
+    train_loss_metric = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
     eval_metrics = [metric_fn()] if metric_fn else []
     # If evaluation is required, make a copy of metric as it will be used by
     # both train and evaluation.
@@ -440,18 +444,20 @@ def run_customized_training_loop(
     latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
     if latest_checkpoint_file:
-      logging.info(
-          'Checkpoint file %s found and restoring from '
-          'checkpoint', latest_checkpoint_file)
+      logging.info('Checkpoint file %s found and restoring from '
+                   'checkpoint', latest_checkpoint_file)
       checkpoint.restore(latest_checkpoint_file)
       logging.info('Loading from checkpoint file completed')

     current_step = optimizer.iterations.numpy()
     checkpoint_name = 'ctl_step_{step}.ckpt'

-    while current_step < total_training_steps:
+    logs = {}
+    callback_list.on_train_begin()
+    while current_step < total_training_steps and not model.stop_training:
       if current_step % steps_per_epoch == 0:
-        callback_list.on_epoch_begin(int(current_step / steps_per_epoch) + 1)
+        callback_list.on_epoch_begin(
+            int(current_step / steps_per_epoch) + 1)

       # Training loss/metric are taking average over steps inside micro
       # training loop. We reset their values before each round.
@@ -461,7 +467,7 @@ def run_customized_training_loop(
       callback_list.on_batch_begin(current_step)
       # Runs several steps in the host while loop.
-      steps = steps_to_run(current_step, steps_per_epoch, steps_per_loop)
+      steps = steps_to_run(current_step, steps_between_evals, steps_per_loop)

       if tf.config.list_physical_devices('GPU'):
         # TODO(zongweiz): merge with train_steps once tf.while_loop
@@ -470,11 +476,9 @@ def run_customized_training_loop(
           train_single_step(train_iterator)
       else:
         # Converts steps to a Tensor to avoid tf.function retracing.
-        train_steps(train_iterator,
-                    tf.convert_to_tensor(steps, dtype=tf.int32))
+        train_steps(train_iterator, tf.convert_to_tensor(steps, dtype=tf.int32))
       train_loss = _float_metric_value(train_loss_metric)
       current_step += steps
-      callback_list.on_batch_end(current_step - 1, {'loss': train_loss})

       # Updates training logging.
       training_status = 'Train Step: %d/%d / loss = %s' % (
@@ -492,8 +496,7 @@ def run_customized_training_loop(
               'learning_rate',
               optimizer.learning_rate(current_step),
               step=current_step)
-          tf.summary.scalar(
-              train_loss_metric.name, train_loss, step=current_step)
+          tf.summary.scalar(train_loss_metric.name, train_loss, step=current_step)
           for metric in train_metrics + model.metrics:
             metric_value = _float_metric_value(metric)
             training_status += ' %s = %f' % (metric.name, metric_value)
@@ -501,7 +504,11 @@ def run_customized_training_loop(
           summary_writer.flush()
       logging.info(training_status)

-      if current_step % steps_per_epoch == 0:
+      # If there is no need for evaluation, we only call on_batch_end with
+      # train_loss; this ensures we get granular global_step/sec on TensorBoard.
+      if current_step % steps_between_evals:
+        callback_list.on_batch_end(current_step - 1, {'loss': train_loss})
+      else:
         # Save a submodel with the step in the file name after each epoch.
         if sub_model_export_name:
           _save_checkpoint(
@@ -514,7 +521,6 @@ def run_customized_training_loop(
         if current_step < total_training_steps:
           _save_checkpoint(strategy, checkpoint, model_dir,
                            checkpoint_name.format(step=current_step))
-        logs = None
         if eval_input_fn:
           logging.info('Running evaluation after step: %s.', current_step)
           logs = _run_evaluation(current_step,
@@ -523,8 +529,15 @@ def run_customized_training_loop(
           eval_loss_metric.reset_states()
           for metric in eval_metrics + model.metrics:
             metric.reset_states()
+        # We add train_loss here rather than call on_batch_end twice to make
+        # sure that no duplicated values are generated.
+        logs['loss'] = train_loss
+        callback_list.on_batch_end(current_step - 1, logs)

-      callback_list.on_epoch_end(int(current_step / steps_per_epoch), logs)
+      # Calls on_epoch_end after each real epoch ends to prevent mis-calculation
+      # of training steps.
+      if current_step % steps_per_epoch == 0:
+        callback_list.on_epoch_end(int(current_step / steps_per_epoch), logs)

     if sub_model_export_name:
       _save_checkpoint(strategy, sub_model_checkpoint, model_dir,
@@ -532,14 +545,11 @@ def run_customized_training_loop(
     _save_checkpoint(strategy, checkpoint, model_dir,
                      checkpoint_name.format(step=current_step))

-    logs = None
     if eval_input_fn:
       logging.info('Running final evaluation after training is complete.')
       logs = _run_evaluation(current_step,
                              _get_input_iterator(eval_input_fn, strategy))
     callback_list.on_epoch_end(int(current_step / steps_per_epoch), logs)

     training_summary = {
         'total_training_steps': total_training_steps,
         'train_loss': _float_metric_value(train_loss_metric),
@@ -557,4 +567,6 @@ def run_customized_training_loop(
     if not _should_export_summary(strategy):
       tf.io.gfile.rmtree(summary_dir)

+    callback_list.on_train_end()
+
     return model
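The loop now drives Keras callbacks through their full lifecycle (on_train_begin through on_train_end) instead of only the batch and epoch hooks. A self-contained sketch of that call sequence; the model and the logged loss values are stand-ins, not code from the diff:

```
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer='sgd', loss='mse')
callback_list = tf.keras.callbacks.CallbackList(
    callbacks=[tf.keras.callbacks.History()], model=model)

callback_list.on_train_begin()
for epoch in range(1, 3):
  callback_list.on_epoch_begin(epoch)
  for step in range(5):
    callback_list.on_batch_begin(step)
    callback_list.on_batch_end(step, {'loss': 0.0})
  callback_list.on_epoch_end(epoch, {'loss': 0.0})
callback_list.on_train_end()
```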
@@ -258,6 +258,7 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
         loss_fn=tf.keras.losses.categorical_crossentropy,
         model_dir=model_dir,
         steps_per_epoch=20,
+        num_eval_per_epoch=4,
         steps_per_loop=10,
         epochs=2,
         train_input_fn=input_fn,
@@ -269,14 +270,15 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
         run_eagerly=False)
     self.assertEqual(callback.epoch_begin, [(1, {}), (2, {})])
     epoch_ends, epoch_end_infos = zip(*callback.epoch_end)
-    self.assertEqual(list(epoch_ends), [1, 2])
+    self.assertEqual(list(epoch_ends), [1, 2, 2])
     for info in epoch_end_infos:
       self.assertIn('accuracy', info)
-    self.assertEqual(callback.batch_begin,
-                     [(0, {}), (10, {}), (20, {}), (30, {})])
+    self.assertEqual(callback.batch_begin, [(0, {}), (5, {}), (10, {}),
+                                            (15, {}), (20, {}), (25, {}),
+                                            (30, {}), (35, {})])
     batch_ends, batch_end_infos = zip(*callback.batch_end)
-    self.assertEqual(list(batch_ends), [9, 19, 29, 39])
+    self.assertEqual(list(batch_ends), [4, 9, 14, 19, 24, 29, 34, 39])
     for info in batch_end_infos:
       self.assertIn('loss', info)
...
@@ -45,6 +45,9 @@ assemble new layers, networks, or models.
     should be masked), the output will have masked positions set to
     approximately zero.

+*   [`MaskedLM`](masked_lm.py) implements a masked language model. It assumes
+    the embedding table variable is passed to it.
+
 *   [ClassificationHead](cls_head.py) A pooling head over a sequence of
     embeddings, commonly used by classification tasks.
...
@@ -18,6 +18,7 @@ from official.nlp.modeling.layers.attention import *
 from official.nlp.modeling.layers.cls_head import *
 from official.nlp.modeling.layers.dense_einsum import DenseEinsum
 from official.nlp.modeling.layers.gated_feedforward import GatedFeedforward
+from official.nlp.modeling.layers.masked_lm import MaskedLM
 from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax
 from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
 from official.nlp.modeling.layers.position_embedding import PositionEmbedding
...
@@ -25,91 +25,74 @@ from official.modeling import tf_utils

 @tf.keras.utils.register_keras_serializable(package='Text')
-class MaskedLM(tf.keras.Model):
+class MaskedLM(tf.keras.layers.Layer):
   """Masked language model network head for BERT modeling.

   This network implements a masked language model based on the provided network.
   It assumes that the network being passed has a "get_embedding_table()" method.

   Arguments:
-    input_width: The innermost dimension of the input tensor to this network.
-    num_predictions: The number of predictions to make per sequence.
-    source_network: The network with the embedding layer to use for the
-      embedding layer.
-    embedding_table: The embedding table of a source network, If None, the
-      `source_network.get_embedding_table()` method is used.
-    activation: The activation, if any, for the dense layer in this network.
-    initializer: The intializer for the dense layer in this network. Defaults to
-      a Glorot uniform initializer.
+    embedding_table: The embedding table of the targets.
+    activation: The activation, if any, for the dense layer.
+    initializer: The initializer for the dense layer. Defaults to a Glorot
+      uniform initializer.
     output: The output style for this network. Can be either 'logits' or
       'predictions'.
   """

   def __init__(self,
-               input_width,
-               num_predictions,
-               source_network,
-               embedding_table=None,
+               embedding_table,
                activation=None,
                initializer='glorot_uniform',
                output='logits',
+               name='cls/predictions',
                **kwargs):
-    if embedding_table is None:
-      embedding_table = source_network.get_embedding_table()
-    vocab_size, hidden_size = embedding_table.shape
-
-    sequence_data = tf.keras.layers.Input(
-        shape=(None, input_width), name='sequence_data', dtype=tf.float32)
-    masked_lm_positions = tf.keras.layers.Input(
-        shape=(num_predictions,), name='masked_lm_positions', dtype=tf.int32)
-    masked_lm_input = tf.keras.layers.Lambda(
-        lambda x: self._gather_indexes(x[0], x[1]))(
-            [sequence_data, masked_lm_positions])
-    lm_data = (
-        tf.keras.layers.Dense(
-            hidden_size,
-            activation=activation,
-            kernel_initializer=initializer,
-            name='cls/predictions/transform/dense')(masked_lm_input))
-    lm_data = tf.keras.layers.LayerNormalization(
-        axis=-1, epsilon=1e-12, name='cls/predictions/transform/LayerNorm')(
-            lm_data)
-    lm_data = tf.keras.layers.Lambda(
-        lambda x: tf.matmul(x, embedding_table, transpose_b=True))(
-            lm_data)
-    logits = Bias(
-        initializer=tf.keras.initializers.Zeros(),
-        name='cls/predictions/output_bias')(
-            lm_data)
-    # We can't use the standard Keras reshape layer here, since it expects
-    # the input and output batch size to be the same.
-    reshape_layer = tf.keras.layers.Lambda(
-        lambda x: tf.reshape(x, [-1, num_predictions, vocab_size]))
-    self.logits = reshape_layer(logits)
-    predictions = tf.keras.layers.Activation(tf.nn.log_softmax)(self.logits)
-
-    if output == 'logits':
-      output_tensors = self.logits
-    elif output == 'predictions':
-      output_tensors = predictions
-    else:
+    super(MaskedLM, self).__init__(name=name, **kwargs)
+    self.embedding_table = embedding_table
+    self.activation = activation
+    self.initializer = tf.keras.initializers.get(initializer)
+
+    if output not in ('predictions', 'logits'):
       raise ValueError(
           ('Unknown `output` value "%s". `output` can be either "logits" or '
            '"predictions"') % output)
+    self._output_type = output

-    super(MaskedLM, self).__init__(
-        inputs=[sequence_data, masked_lm_positions],
-        outputs=output_tensors,
-        **kwargs)
+  def build(self, input_shape):
+    self._vocab_size, hidden_size = self.embedding_table.shape
+    self.dense = tf.keras.layers.Dense(
+        hidden_size,
+        activation=self.activation,
+        kernel_initializer=self.initializer,
+        name='transform/dense')
+    self.layer_norm = tf.keras.layers.LayerNormalization(
+        axis=-1, epsilon=1e-12, name='transform/LayerNorm')
+    self.bias = self.add_weight(
+        'output_bias/bias',
+        shape=(self._vocab_size,),
+        initializer='zeros',
+        trainable=True)
+    super(MaskedLM, self).build(input_shape)
+
+  def call(self, sequence_data, masked_positions):
+    masked_lm_input = self._gather_indexes(sequence_data, masked_positions)
+    lm_data = self.dense(masked_lm_input)
+    lm_data = self.layer_norm(lm_data)
+    lm_data = tf.matmul(lm_data, self.embedding_table, transpose_b=True)
+    logits = tf.nn.bias_add(lm_data, self.bias)
+    masked_positions_shape = tf_utils.get_shape_list(
+        masked_positions, name='masked_positions_tensor')
+    logits = tf.reshape(logits,
+                        [-1, masked_positions_shape[1], self._vocab_size])
+    if self._output_type == 'logits':
+      return logits
+    return tf.nn.log_softmax(logits)

   def get_config(self):
-    raise NotImplementedError('MaskedLM cannot be directly serialized at this '
-                              'time. Please use it only in Layers or '
-                              'functionally subclassed Models/Networks.')
+    raise NotImplementedError('MaskedLM cannot be directly serialized because '
+                              'it has variable sharing logic.')

   def _gather_indexes(self, sequence_tensor, positions):
     """Gathers the vectors at the specific positions.
@@ -139,51 +122,3 @@ class MaskedLM(tf.keras.Model):
     output_tensor = tf.gather(flat_sequence_tensor, flat_positions)

     return output_tensor
-
-
-@tf.keras.utils.register_keras_serializable(package='Text')
-# Temporary until we can create a Dense layer that ties the embedding.
-class Bias(tf.keras.layers.Layer):
-  """Adds a bias term to an input."""
-
-  def __init__(self,
-               initializer='zeros',
-               regularizer=None,
-               constraint=None,
-               activation=None,
-               **kwargs):
-    super(Bias, self).__init__(**kwargs)
-    self._initializer = tf.keras.initializers.get(initializer)
-    self._regularizer = tf.keras.regularizers.get(regularizer)
-    self._constraint = tf.keras.constraints.get(constraint)
-    self._activation = tf.keras.activations.get(activation)
-
-  def build(self, input_shape):
-    input_shape = tf.TensorShape(input_shape)
-    self._bias = self.add_weight(
-        'bias',
-        shape=input_shape[1:],
-        initializer=self._initializer,
-        regularizer=self._regularizer,
-        constraint=self._constraint,
-        dtype=self._dtype,
-        trainable=True)
-    super(Bias, self).build(input_shape)
-
-  def get_config(self):
-    config = {
-        'activation': tf.keras.activations.serialize(self._activation),
-        'initializer': tf.keras.initializers.serialize(self._initializer),
-        'regularizer': tf.keras.regularizers.serialize(self._regularizer),
-        'constraint': tf.keras.constraints.serialize(self._constraint)
-    }
-    base_config = super(Bias, self).get_config()
-    return dict(list(base_config.items()) + list(config.items()))
-
-  def call(self, inputs):
-    outputs = tf.nn.bias_add(inputs, self._bias)
-    if self._activation is not None:
-      return self._activation(outputs)  # pylint: disable=not-callable
-    else:
-      return outputs
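A minimal sketch of the rewritten layer used standalone; the shapes are illustrative, and the embedding table would normally come from an encoder's get_embedding_table() rather than a fresh variable:

```
import numpy as np
import tensorflow as tf

vocab_size, hidden_size, num_predictions = 100, 64, 8
embedding_table = tf.Variable(
    tf.random.normal([vocab_size, hidden_size]), name='word_embeddings')

layer = MaskedLM(embedding_table=embedding_table, output='predictions')
sequence_data = tf.random.normal([2, 16, hidden_size])
masked_positions = tf.constant(
    np.random.randint(16, size=(2, num_predictions)), dtype=tf.int32)

log_probs = layer(sequence_data, masked_positions=masked_positions)
print(log_probs.shape)  # (2, 8, 100): [batch, num_predictions, vocab_size]
```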
@@ -23,7 +23,7 @@ import tensorflow as tf
 from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import

-from official.nlp.modeling.networks import masked_lm
+from official.nlp.modeling.layers import masked_lm
 from official.nlp.modeling.networks import transformer_encoder
@@ -32,13 +32,12 @@ from official.nlp.modeling.networks import transformer_encoder
 @keras_parameterized.run_all_keras_modes
 class MaskedLMTest(keras_parameterized.TestCase):

-  def create_network(self,
-                     vocab_size,
-                     sequence_length,
-                     hidden_size,
-                     num_predictions,
-                     output='predictions',
-                     xformer_stack=None):
+  def create_layer(self,
+                   vocab_size,
+                   sequence_length,
+                   hidden_size,
+                   output='predictions',
+                   xformer_stack=None):
     # First, create a transformer stack that we can use to get the LM's
     # vocabulary weight.
     if xformer_stack is None:
@@ -49,82 +48,32 @@ class MaskedLMTest(keras_parameterized.TestCase):
           hidden_size=hidden_size,
           num_attention_heads=4,
       )
-    word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    lm_outputs, _ = xformer_stack([word_ids, mask, type_ids])

     # Create a maskedLM from the transformer stack.
-    test_network = masked_lm.MaskedLM(
-        num_predictions=num_predictions,
-        input_width=lm_outputs.shape[-1],
-        source_network=xformer_stack,
+    test_layer = masked_lm.MaskedLM(
+        embedding_table=xformer_stack.get_embedding_table(),
         output=output)
-    return test_network
+    return test_layer

-  def test_network_creation(self):
+  def test_layer_creation(self):
     vocab_size = 100
     sequence_length = 32
     hidden_size = 64
     num_predictions = 21
-    test_network = self.create_network(
+    test_layer = self.create_layer(
         vocab_size=vocab_size,
         sequence_length=sequence_length,
-        hidden_size=hidden_size,
-        num_predictions=num_predictions)
+        hidden_size=hidden_size)

     # Make sure that the output tensor of the masked LM is the right shape.
     lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
-    masked_lm_positions = tf.keras.Input(
-        shape=(num_predictions,), dtype=tf.int32)
-    output = test_network([lm_input_tensor, masked_lm_positions])
+    masked_positions = tf.keras.Input(shape=(num_predictions,), dtype=tf.int32)
+    output = test_layer(lm_input_tensor, masked_positions=masked_positions)

     expected_output_shape = [None, num_predictions, vocab_size]
     self.assertEqual(expected_output_shape, output.shape.as_list())
-  def test_network_invocation_with_internal_logits(self):
-    vocab_size = 100
-    sequence_length = 32
-    hidden_size = 64
-    num_predictions = 21
-    test_network = self.create_network(
-        vocab_size=vocab_size,
-        sequence_length=sequence_length,
-        hidden_size=hidden_size,
-        num_predictions=num_predictions)
-
-    # Create a model from the masked LM layer.
-    lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
-    masked_lm_positions = tf.keras.Input(
-        shape=(num_predictions,), dtype=tf.int32)
-    output = test_network([lm_input_tensor, masked_lm_positions])
-    model = tf.keras.Model([lm_input_tensor, masked_lm_positions], output)
-    logits_model = tf.keras.Model(test_network.inputs, test_network.logits)
-
-    # Invoke the masked LM on some fake data to make sure there are no runtime
-    # errors in the code.
-    batch_size = 3
-    lm_input_data = 10 * np.random.random_sample(
-        (batch_size, sequence_length, hidden_size))
-    masked_position_data = np.random.randint(
-        2, size=(batch_size, num_predictions))
-    outputs = model.predict([lm_input_data, masked_position_data])
-    logits = logits_model.predict([lm_input_data, masked_position_data])
-
-    # Ensure that the tensor shapes are correct.
-    expected_output_shape = (batch_size, num_predictions, vocab_size)
-    self.assertEqual(expected_output_shape, outputs.shape)
-    self.assertEqual(expected_output_shape, logits.shape)
-
-    # Ensure that the logits, when softmaxed, create the outputs.
-    input_tensor = tf.keras.Input(expected_output_shape[1:])
-    output_tensor = tf.keras.layers.Activation(tf.nn.log_softmax)(input_tensor)
-    softmax_model = tf.keras.Model(input_tensor, output_tensor)
-    calculated_softmax = softmax_model.predict(logits)
-    self.assertAllClose(outputs, calculated_softmax)
-
-  def test_network_invocation_with_external_logits(self):
+  def test_layer_invocation_with_external_logits(self):
     vocab_size = 100
     sequence_length = 32
     hidden_size = 64
@@ -136,31 +85,28 @@ class MaskedLMTest(keras_parameterized.TestCase):
         hidden_size=hidden_size,
         num_attention_heads=4,
     )
-    test_network = self.create_network(
+    test_layer = self.create_layer(
         vocab_size=vocab_size,
         sequence_length=sequence_length,
         hidden_size=hidden_size,
-        num_predictions=num_predictions,
         xformer_stack=xformer_stack,
         output='predictions')
-    logit_network = self.create_network(
+    logit_layer = self.create_layer(
         vocab_size=vocab_size,
         sequence_length=sequence_length,
         hidden_size=hidden_size,
-        num_predictions=num_predictions,
         xformer_stack=xformer_stack,
         output='logits')
-    logit_network.set_weights(test_network.get_weights())

     # Create a model from the masked LM layer.
     lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
-    masked_lm_positions = tf.keras.Input(
-        shape=(num_predictions,), dtype=tf.int32)
-    output = test_network([lm_input_tensor, masked_lm_positions])
-    logit_output = logit_network([lm_input_tensor, masked_lm_positions])
-
-    model = tf.keras.Model([lm_input_tensor, masked_lm_positions], output)
-    logits_model = tf.keras.Model(([lm_input_tensor, masked_lm_positions]),
+    masked_positions = tf.keras.Input(shape=(num_predictions,), dtype=tf.int32)
+    output = test_layer(lm_input_tensor, masked_positions)
+    logit_output = logit_layer(lm_input_tensor, masked_positions)
+    logit_output = tf.keras.layers.Activation(tf.nn.log_softmax)(logit_output)
+    logit_layer.set_weights(test_layer.get_weights())
+    model = tf.keras.Model([lm_input_tensor, masked_positions], output)
+    logits_model = tf.keras.Model(([lm_input_tensor, masked_positions]),
                                   logit_output)

     # Invoke the masked LM on some fake data to make sure there are no runtime
@@ -169,40 +115,33 @@ class MaskedLMTest(keras_parameterized.TestCase):
     lm_input_data = 10 * np.random.random_sample(
         (batch_size, sequence_length, hidden_size))
     masked_position_data = np.random.randint(
-        2, size=(batch_size, num_predictions))
-    outputs = model.predict([lm_input_data, masked_position_data])
-    logits = logits_model.predict([lm_input_data, masked_position_data])
+        sequence_length, size=(batch_size, num_predictions))
+    # ref_outputs = model.predict([lm_input_data, masked_position_data])
+    # outputs = logits_model.predict([lm_input_data, masked_position_data])
+    ref_outputs = model([lm_input_data, masked_position_data])
+    outputs = logits_model([lm_input_data, masked_position_data])

     # Ensure that the tensor shapes are correct.
     expected_output_shape = (batch_size, num_predictions, vocab_size)
+    self.assertEqual(expected_output_shape, ref_outputs.shape)
     self.assertEqual(expected_output_shape, outputs.shape)
-    self.assertEqual(expected_output_shape, logits.shape)
-
-    # Ensure that the logits, when softmaxed, create the outputs.
-    input_tensor = tf.keras.Input(expected_output_shape[1:])
-    output_tensor = tf.keras.layers.Activation(tf.nn.log_softmax)(input_tensor)
-    softmax_model = tf.keras.Model(input_tensor, output_tensor)
-    calculated_softmax = softmax_model.predict(logits)
-    self.assertAllClose(outputs, calculated_softmax)
+    self.assertAllClose(ref_outputs, outputs)

-  def test_network_invocation(self):
+  def test_layer_invocation(self):
     vocab_size = 100
     sequence_length = 32
     hidden_size = 64
     num_predictions = 21
-    test_network = self.create_network(
+    test_layer = self.create_layer(
         vocab_size=vocab_size,
         sequence_length=sequence_length,
-        hidden_size=hidden_size,
-        num_predictions=num_predictions)
+        hidden_size=hidden_size)

     # Create a model from the masked LM layer.
     lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
-    masked_lm_positions = tf.keras.Input(
-        shape=(num_predictions,), dtype=tf.int32)
-    output = test_network([lm_input_tensor, masked_lm_positions])
-    model = tf.keras.Model([lm_input_tensor, masked_lm_positions], output)
+    masked_positions = tf.keras.Input(shape=(num_predictions,), dtype=tf.int32)
+    output = test_layer(lm_input_tensor, masked_positions)
+    model = tf.keras.Model([lm_input_tensor, masked_positions], output)

     # Invoke the masked LM on some fake data to make sure there are no runtime
     # errors in the code.
@@ -215,12 +154,8 @@ class MaskedLMTest(keras_parameterized.TestCase):

   def test_unknown_output_type_fails(self):
     with self.assertRaisesRegex(ValueError, 'Unknown `output` value "bad".*'):
-      _ = self.create_network(
-          vocab_size=8,
-          sequence_length=8,
-          hidden_size=8,
-          num_predictions=8,
-          output='bad')
+      _ = self.create_layer(
+          vocab_size=8, sequence_length=8, hidden_size=8, output='bad')


 if __name__ == '__main__':
...
@@ -23,6 +23,7 @@ import numpy as np
 import tensorflow as tf
 from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import

+from official.nlp.modeling import layers
 from official.nlp.modeling import networks
 from official.nlp.modeling.losses import weighted_sparse_categorical_crossentropy
@@ -48,20 +49,18 @@ class ClassificationLossTest(keras_parameterized.TestCase):
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    lm_outputs, _ = xformer_stack([word_ids, mask, type_ids])
+    _ = xformer_stack([word_ids, mask, type_ids])

     # Create a maskedLM from the transformer stack.
-    test_network = networks.MaskedLM(
-        num_predictions=num_predictions,
-        input_width=lm_outputs.shape[-1],
-        source_network=xformer_stack,
+    test_layer = layers.MaskedLM(
+        embedding_table=xformer_stack.get_embedding_table(),
         output=output)

     # Create a model from the masked LM layer.
     lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
     masked_lm_positions = tf.keras.Input(
         shape=(num_predictions,), dtype=tf.int32)
-    output = test_network([lm_input_tensor, masked_lm_positions])
+    output = test_layer(lm_input_tensor, masked_positions=masked_lm_positions)
     return tf.keras.Model([lm_input_tensor, masked_lm_positions], output)

   def create_classification_model(self, input_width, num_classes):
...
...@@ -25,6 +25,7 @@ from typing import List, Optional ...@@ -25,6 +25,7 @@ from typing import List, Optional
import gin import gin
import tensorflow as tf import tensorflow as tf
from official.nlp.modeling import layers
from official.nlp.modeling import networks from official.nlp.modeling import networks
...@@ -47,8 +48,8 @@ class BertPretrainer(tf.keras.Model): ...@@ -47,8 +48,8 @@ class BertPretrainer(tf.keras.Model):
num_token_predictions: Number of tokens to predict from the masked LM. num_token_predictions: Number of tokens to predict from the masked LM.
embedding_table: Embedding table of a network. If None, the embedding_table: Embedding table of a network. If None, the
"network.get_embedding_table()" is used. "network.get_embedding_table()" is used.
activation: The activation (if any) to use in the masked LM network. activation: The activation (if any) to use in the masked LM network. If
If None, no activation will be used. None, no activation will be used.
initializer: The initializer (if any) to use in the masked LM and initializer: The initializer (if any) to use in the masked LM and
classification networks. Defaults to a Glorot uniform initializer. classification networks. Defaults to a Glorot uniform initializer.
output: The output style for this network. Can be either 'logits' or output: The output style for this network. Can be either 'logits' or
...@@ -106,16 +107,16 @@ class BertPretrainer(tf.keras.Model): ...@@ -106,16 +107,16 @@ class BertPretrainer(tf.keras.Model):
dtype=tf.int32) dtype=tf.int32)
inputs.append(masked_lm_positions) inputs.append(masked_lm_positions)
self.masked_lm = networks.MaskedLM( if embedding_table is None:
num_predictions=num_token_predictions, embedding_table = self.encoder.get_embedding_table()
input_width=sequence_output.shape[-1], self.masked_lm = layers.MaskedLM(
source_network=network,
embedding_table=embedding_table, embedding_table=embedding_table,
activation=activation, activation=activation,
initializer=initializer, initializer=initializer,
output=output, output=output,
name='masked_lm') name='cls/predictions')
lm_outputs = self.masked_lm([sequence_output, masked_lm_positions]) lm_outputs = self.masked_lm(
sequence_output, masked_positions=masked_lm_positions)
self.classification = networks.Classification( self.classification = networks.Classification(
input_width=cls_output.shape[-1], input_width=cls_output.shape[-1],
...@@ -126,7 +127,9 @@ class BertPretrainer(tf.keras.Model): ...@@ -126,7 +127,9 @@ class BertPretrainer(tf.keras.Model):
sentence_outputs = self.classification(cls_output) sentence_outputs = self.classification(cls_output)
super(BertPretrainer, self).__init__( super(BertPretrainer, self).__init__(
inputs=inputs, outputs=[lm_outputs, sentence_outputs], **kwargs) inputs=inputs,
outputs=dict(masked_lm=lm_outputs, classification=sentence_outputs),
**kwargs)
def get_config(self): def get_config(self):
return self._config return self._config
@@ -151,8 +154,8 @@ class BertPretrainerV2(tf.keras.Model):
     num_masked_tokens: Number of tokens to predict from the masked LM.
     encoder_network: A transformer network. This network should output a
       sequence output and a classification output.
-    mlm_activation: The activation (if any) to use in the masked LM network.
-      If None, no activation will be used.
+    mlm_activation: The activation (if any) to use in the masked LM network. If
+      None, no activation will be used.
     mlm_initializer: The initializer (if any) to use in the masked LM. Default
       to a Glorot uniform initializer.
     classification_heads: A list of optional head layers to transform on encoder
@@ -193,17 +196,18 @@ class BertPretrainerV2(tf.keras.Model):
     outputs = dict()
     if num_masked_tokens > 0:
-      self.masked_lm = networks.MaskedLM(
-          num_predictions=num_masked_tokens,
-          input_width=sequence_output.shape[-1],
-          source_network=self.encoder_network,
+      self.masked_lm = layers.MaskedLM(
+          embedding_table=self.encoder_network.get_embedding_table(),
           activation=mlm_activation,
           initializer=mlm_initializer,
-          name='masked_lm')
-      masked_lm_positions = copy.copy(self.masked_lm.inputs[-1])
+          name='cls/predictions')
+      masked_lm_positions = tf.keras.layers.Input(
+          shape=(num_masked_tokens,),
+          name='masked_lm_positions',
+          dtype=tf.int32)
       inputs.append(masked_lm_positions)
-      outputs['lm_output'] = self.masked_lm(
-          [sequence_output, masked_lm_positions])
+      outputs['lm_output'] = self.masked_lm(
+          sequence_output, masked_positions=masked_lm_positions)
     for cls_head in self.classification_heads:
       outputs[cls_head.name] = cls_head(sequence_output)
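
`layers.MaskedLM` is now called with the full sequence output plus a `masked_positions` tensor rather than a packed input list. The core of what it does is gather the hidden vectors at the masked indices before projecting them back onto the embedding table; a standalone sketch of that gather, with made-up shapes:

import tensorflow as tf

hidden = tf.random.normal([2, 8, 16])      # [batch, seq_len, width] sequence output
positions = tf.constant([[1, 3], [0, 5]])  # [batch, num_predictions] masked indices

# Select the hidden vector at each masked position, per batch element.
masked_hidden = tf.gather(hidden, positions, batch_dims=1)  # shape [2, 2, 16]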
...
@@ -50,16 +50,19 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    lm_mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+    masked_lm_positions = tf.keras.Input(
+        shape=(num_token_predictions,), dtype=tf.int32)
     # Invoke the trainer model on the inputs. This causes the layer to be built.
-    lm_outs, cls_outs = bert_trainer_model([word_ids, mask, type_ids, lm_mask])
+    outputs = bert_trainer_model(
+        [word_ids, mask, type_ids, masked_lm_positions])
     # Validate that the outputs are of the expected shape.
     expected_lm_shape = [None, num_token_predictions, vocab_size]
     expected_classification_shape = [None, num_classes]
-    self.assertAllEqual(expected_lm_shape, lm_outs.shape.as_list())
-    self.assertAllEqual(expected_classification_shape, cls_outs.shape.as_list())
+    self.assertAllEqual(expected_lm_shape, outputs['masked_lm'].shape.as_list())
+    self.assertAllEqual(expected_classification_shape,
+                        outputs['classification'].shape.as_list())

   def test_bert_trainer_tensor_call(self):
     """Validate that the Keras object can be invoked."""
@@ -81,7 +84,7 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     # Invoke the trainer model on the tensors. In Eager mode, this does the
     # actual calculation. (We can't validate the outputs, since the network is
     # too complex: this simply ensures we're not hitting runtime errors.)
-    _, _ = bert_trainer_model([word_ids, mask, type_ids, lm_mask])
+    _ = bert_trainer_model([word_ids, mask, type_ids, lm_mask])

   def test_serialize_deserialize(self):
     """Validate that the BERT trainer can be serialized and deserialized."""
@@ -123,7 +126,7 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    lm_mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+    lm_mask = tf.keras.Input(shape=(num_token_predictions,), dtype=tf.int32)
     # Invoke the trainer model on the inputs. This causes the layer to be built.
     outputs = bert_trainer_model([word_ids, mask, type_ids, lm_mask])
...
@@ -16,8 +16,6 @@ Self-supervised Learning of Language Representations]
 (https://arxiv.org/abs/1909.11942). Compared with [BERT](https://arxiv.org/abs/1810.04805), ALBERT refactorizes embedding parameters
 into two smaller matrices and shares parameters across layers.

-* [`MaskedLM`](masked_lm.py) implements a masked language model for BERT pretraining. It assumes that the network being passed has a `get_embedding_table()` method.
-
 * [`Classification`](classification.py) contains a single hidden layer, and is
 intended for use as a classification or regression (if number of classes is set
 to 1) head.
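
The `Classification` head kept in this README is the same one instantiated in the pretrainer diff above. A hedged usage sketch: `input_width` is confirmed by the constructor call in the diff, while `num_classes` and the pooled-output shape are assumptions for illustration:

import tensorflow as tf
from official.nlp.modeling import networks

cls_output = tf.random.normal([4, 768])  # stand-in for the encoder's pooled output
classifier = networks.Classification(input_width=768, num_classes=3)
logits = classifier(cls_output)          # -> [4, 3]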
...
@@ -16,7 +16,6 @@
 from official.nlp.modeling.networks.albert_transformer_encoder import AlbertTransformerEncoder
 from official.nlp.modeling.networks.classification import Classification
 from official.nlp.modeling.networks.encoder_scaffold import EncoderScaffold
-from official.nlp.modeling.networks.masked_lm import MaskedLM
 from official.nlp.modeling.networks.span_labeling import SpanLabeling
 from official.nlp.modeling.networks.token_classification import TokenClassification
 from official.nlp.modeling.networks.transformer_encoder import TransformerEncoder
@@ -43,25 +43,25 @@ class MaskedLMTask(base_task.Task):
     return bert.instantiate_from_cfg(self.task_config.network)

   def build_losses(self,
-                   features,
+                   labels,
                    model_outputs,
                    metrics,
                    aux_losses=None) -> tf.Tensor:
     metrics = dict([(metric.name, metric) for metric in metrics])
     lm_output = tf.nn.log_softmax(model_outputs['lm_output'], axis=-1)
     mlm_loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
-        labels=features['masked_lm_ids'],
+        labels=labels['masked_lm_ids'],
         predictions=lm_output,
-        weights=features['masked_lm_weights'])
+        weights=labels['masked_lm_weights'])
     metrics['lm_example_loss'].update_state(mlm_loss)
-    if 'next_sentence_labels' in features:
+    if 'next_sentence_labels' in labels:
       policy = tf.keras.mixed_precision.experimental.global_policy()
       if policy.name == 'mixed_bfloat16':  # b/158514794: bf16 is not stable.
         policy = tf.float32
       predictions = tf.keras.layers.Activation(
           tf.nn.log_softmax, dtype=policy)(model_outputs['next_sentence'])
-      sentence_labels = features['next_sentence_labels']
+      sentence_labels = labels['next_sentence_labels']
       sentence_loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
           labels=sentence_labels,
           predictions=predictions)
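
The loss path above log-softmaxes the LM logits and hands them, with per-position weights, to `weighted_sparse_categorical_crossentropy_loss`. A standalone sketch of the computation that name describes (not the library implementation; the `1e-5` stabilizer against all-zero weights is an assumption):

import tensorflow as tf

def weighted_sparse_xent(labels, log_probs, weights):
  # log_probs: [batch, num_predictions, vocab_size], already log-softmaxed.
  # labels, weights: [batch, num_predictions]; weight 0.0 marks padded slots.
  per_position = -tf.gather(log_probs, labels, batch_dims=2)
  weights = tf.cast(weights, per_position.dtype)
  return tf.reduce_sum(per_position * weights) / (tf.reduce_sum(weights) + 1e-5)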
@@ -112,15 +112,15 @@ class MaskedLMTask(base_task.Task):
       metrics.append(tf.keras.metrics.Mean(name='next_sentence_loss'))
     return metrics

-  def process_metrics(self, metrics, inputs, outputs):
+  def process_metrics(self, metrics, labels, model_outputs):
     metrics = dict([(metric.name, metric) for metric in metrics])
     if 'masked_lm_accuracy' in metrics:
-      metrics['masked_lm_accuracy'].update_state(inputs['masked_lm_ids'],
-                                                 outputs['lm_output'],
-                                                 inputs['masked_lm_weights'])
+      metrics['masked_lm_accuracy'].update_state(labels['masked_lm_ids'],
+                                                 model_outputs['lm_output'],
+                                                 labels['masked_lm_weights'])
     if 'next_sentence_accuracy' in metrics:
       metrics['next_sentence_accuracy'].update_state(
-          inputs['next_sentence_labels'], outputs['next_sentence'])
+          labels['next_sentence_labels'], model_outputs['next_sentence'])

   def train_step(self, inputs, model: tf.keras.Model,
                  optimizer: tf.keras.optimizers.Optimizer, metrics):
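
In the renamed `process_metrics`, `masked_lm_weights` is passed as the metric's third positional argument, which Keras treats as `sample_weight`, so padded prediction slots do not dilute the accuracy. A self-contained sketch with made-up shapes:

import tensorflow as tf

acc = tf.keras.metrics.SparseCategoricalAccuracy(name='masked_lm_accuracy')
labels = tf.constant([[5, 2]])         # [batch, num_predictions]
probs = tf.random.uniform([1, 2, 10])  # [batch, num_predictions, vocab_size]
weights = tf.constant([[1.0, 0.0]])    # second slot is padding, ignored
acc.update_state(labels, probs, sample_weight=weights)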
@@ -139,7 +139,7 @@ class MaskedLMTask(base_task.Task):
       outputs = model(inputs, training=True)
       # Computes per-replica loss.
       loss = self.build_losses(
-          features=inputs,
+          labels=inputs,
           model_outputs=outputs,
           metrics=metrics,
           aux_losses=model.losses)
@@ -166,7 +166,7 @@ class MaskedLMTask(base_task.Task):
     """
     outputs = self.inference_step(inputs, model)
     loss = self.build_losses(
-        features=inputs,
+        labels=inputs,
         model_outputs=outputs,
         metrics=metrics,
         aux_losses=model.losses)
...
@@ -29,9 +29,9 @@ from official.nlp.modeling import losses as loss_lib
 @dataclasses.dataclass
 class SentencePredictionConfig(cfg.TaskConfig):
   """The model config."""
-  # At most one of `pretrain_checkpoint_dir` and `hub_module_url` can
+  # At most one of `init_checkpoint` and `hub_module_url` can
   # be specified.
-  pretrain_checkpoint_dir: str = ''
+  init_checkpoint: str = ''
   hub_module_url: str = ''
   network: bert.BertPretrainerConfig = bert.BertPretrainerConfig(
       num_masked_tokens=0,
@@ -52,7 +52,7 @@ class SentencePredictionTask(base_task.Task):
   def __init__(self, params=cfg.TaskConfig):
     super(SentencePredictionTask, self).__init__(params)
-    if params.hub_module_url and params.pretrain_checkpoint_dir:
+    if params.hub_module_url and params.init_checkpoint:
       raise ValueError('At most one of `hub_module_url` and '
                        '`pretrain_checkpoint_dir` can be specified.')
     if params.hub_module_url:
@@ -79,12 +79,11 @@ class SentencePredictionTask(base_task.Task):
     else:
       return bert.instantiate_from_cfg(self.task_config.network)

-  def build_losses(self, features, model_outputs, aux_losses=None) -> tf.Tensor:
-    labels = features
+  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
     loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
         labels=labels,
-        predictions=tf.nn.log_softmax(model_outputs['sentence_prediction'],
-                                      axis=-1))
+        predictions=tf.nn.log_softmax(
+            model_outputs['sentence_prediction'], axis=-1))

     if aux_losses:
       loss += tf.add_n(aux_losses)
@@ -93,6 +92,7 @@ class SentencePredictionTask(base_task.Task):
   def build_inputs(self, params, input_context=None):
     """Returns tf.data.Dataset for sentence_prediction task."""
     if params.input_path == 'dummy':
+
       def dummy_data(_):
         dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
         x = dict(
@@ -113,22 +113,22 @@ class SentencePredictionTask(base_task.Task):
   def build_metrics(self, training=None):
     del training
-    metrics = [
-        tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')
-    ]
+    metrics = [tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')]
     return metrics

-  def process_metrics(self, metrics, labels, outputs):
+  def process_metrics(self, metrics, labels, model_outputs):
     for metric in metrics:
-      metric.update_state(labels, outputs['sentence_prediction'])
+      metric.update_state(labels, model_outputs['sentence_prediction'])

-  def process_compiled_metrics(self, compiled_metrics, labels, outputs):
-    compiled_metrics.update_state(labels, outputs['sentence_prediction'])
+  def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
+    compiled_metrics.update_state(labels, model_outputs['sentence_prediction'])

   def initialize(self, model):
     """Load a pretrained checkpoint (if exists) and then train from iter 0."""
-    pretrain_ckpt_dir = self.task_config.pretrain_checkpoint_dir
-    if not pretrain_ckpt_dir:
+    ckpt_dir_or_file = self.task_config.init_checkpoint
+    if tf.io.gfile.isdir(ckpt_dir_or_file):
+      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
+    if not ckpt_dir_or_file:
       return
     pretrain2finetune_mapping = {
@@ -138,10 +138,7 @@ class SentencePredictionTask(base_task.Task):
             model.checkpoint_items['sentence_prediction.pooler_dense'],
     }
     ckpt = tf.train.Checkpoint(**pretrain2finetune_mapping)
-    latest_pretrain_ckpt = tf.train.latest_checkpoint(pretrain_ckpt_dir)
-    if latest_pretrain_ckpt is None:
-      raise FileNotFoundError(
-          'Cannot find pretrain checkpoint under {}'.format(pretrain_ckpt_dir))
-    status = ckpt.restore(latest_pretrain_ckpt)
+    status = ckpt.restore(ckpt_dir_or_file)
     status.expect_partial().assert_existing_objects_matched()
-    logging.info('finished loading pretrained checkpoint.')
+    logging.info('finished loading pretrained checkpoint from %s',
+                 ckpt_dir_or_file)
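
`init_checkpoint` now accepts either a directory or a concrete checkpoint prefix, with the `isdir` branch resolving a directory to its newest checkpoint before the single `ckpt.restore` call. The same resolution pattern in isolation, with a hypothetical path:

import tensorflow as tf

ckpt_dir_or_file = '/tmp/pretrain_ckpts'  # hypothetical: a directory or a 'ckpt-N' prefix
if tf.io.gfile.isdir(ckpt_dir_or_file):
  # Returns e.g. '/tmp/pretrain_ckpts/ckpt-5', or None if the directory is empty.
  ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)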
@@ -43,8 +43,10 @@ class SentencePredictionTaskTest(tf.test.TestCase):

   def test_task(self):
     config = sentence_prediction.SentencePredictionConfig(
+        init_checkpoint=self.get_temp_dir(),
         network=bert.BertPretrainerConfig(
-            encoders.TransformerEncoderConfig(vocab_size=30522, num_layers=1),
+            encoder=encoders.TransformerEncoderConfig(
+                vocab_size=30522, num_layers=1),
             num_masked_tokens=0,
             cls_heads=[
                 bert.ClsHeadConfig(
@@ -62,6 +64,21 @@ class SentencePredictionTaskTest(tf.test.TestCase):
     task.train_step(next(iterator), model, optimizer, metrics=metrics)
     task.validation_step(next(iterator), model, metrics=metrics)

+    # Saves a checkpoint.
+    pretrain_cfg = bert.BertPretrainerConfig(
+        encoder=encoders.TransformerEncoderConfig(
+            vocab_size=30522, num_layers=1),
+        num_masked_tokens=20,
+        cls_heads=[
+            bert.ClsHeadConfig(
+                inner_dim=10, num_classes=3, name="next_sentence")
+        ])
+    pretrain_model = bert.instantiate_from_cfg(pretrain_cfg)
+    ckpt = tf.train.Checkpoint(
+        model=pretrain_model, **pretrain_model.checkpoint_items)
+    ckpt.save(config.init_checkpoint)
+    task.initialize(model)
+
   def _export_bert_tfhub(self):
     bert_config = configs.BertConfig(
         vocab_size=30522,
...
@@ -218,7 +218,7 @@ def get_callbacks():
   time_callback = keras_utils.TimeHistory(
       FLAGS.batch_size,
       FLAGS.log_steps,
-      FLAGS.model_dir if FLAGS.enable_tensorboard else None)
+      logdir=FLAGS.model_dir if FLAGS.enable_tensorboard else None)
   callbacks.append(time_callback)
   if FLAGS.enable_tensorboard:
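
Passing `logdir` by keyword keeps the call correct even if `TimeHistory` later grows parameters between `log_steps` and `logdir`. A standalone sketch of the keyword form; the import path and flag values are assumptions for illustration:

from official.utils.misc import keras_utils

time_callback = keras_utils.TimeHistory(
    128,                     # batch_size
    100,                     # log_steps
    logdir='/tmp/model_dir')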
...