Commit c9ac3e2c authored by Hongkun Yu, committed by A. Unique TensorFlower

Set up official/core.

Move nlp/tasks/masked_lm. Make sure task tests do not depend on internal data or Orbit, so that they can be open sourced.

PiperOrigin-RevId: 314973560
parent 23c87aaa
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines the base task abstraction."""
import functools
from typing import Any, Callable, Optional
import tensorflow as tf
from official.modeling.hyperparams import config_definitions as cfg
from official.utils import registry
class Task(tf.Module):
"""A single-replica view of training procedure.
Tasks provide artifacts for training/evaluation procedures, including
loading/iterating over datasets, initializing the model, and calculating the
loss and customized metrics with reduction.
"""
# Special keys in train/validate step returned logs.
loss = "loss"
def __init__(self, params: cfg.TaskConfig):
self._task_config = params
@property
def task_config(self) -> cfg.TaskConfig:
return self._task_config
def initialize(self, model: tf.keras.Model):
"""A callback function used as CheckpointManager's init_fn.
This function will be called when no checkpoint is found for the model.
If there is a checkpoint, the checkpoint will be loaded and this function
will not be called. You can use this callback function to load a pretrained
checkpoint saved under a directory other than the model_dir.
Args:
model: The keras.Model built or used by this task.
"""
pass
def build_model(self) -> tf.keras.Model:
"""Creates the model architecture.
Returns:
A model instance.
"""
# TODO(hongkuny): the base task should call network factory.
pass
def compile_model(self,
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
loss=None,
train_step: Optional[Callable[..., Any]] = None,
validation_step: Optional[Callable[..., Any]] = None,
**kwargs) -> tf.keras.Model:
"""Compiles the model with objects created by the task.
This method should not be used in customized training loop implementations.
Args:
model: a keras.Model.
optimizer: the keras optimizer.
loss: a callable/list of losses.
train_step: optional train step function defined by the task.
validation_step: optional validation step function defined by the task.
**kwargs: other kwargs consumed by keras.Model compile().
Returns:
a compiled keras.Model.
"""
if bool(loss is None) == bool(train_step is None):
raise ValueError("Exactly one of `loss` and `train_step` must be "
"specified; they are mutually exclusive.")
model.compile(optimizer=optimizer, loss=loss, **kwargs)
if train_step:
model.train_step = functools.partial(
train_step, model=model, optimizer=model.optimizer)
if validation_step:
model.test_step = functools.partial(validation_step, model=model)
return model
def build_inputs(self,
params: cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a dataset or a nested structure of dataset functions.
Dataset functions define per-host datasets with the per-replica batch size.
Args:
params: hyperparams to create input pipelines.
input_context: optional distribution input pipeline context.
Returns:
A nested structure of per-replica input functions.
"""
pass
def build_losses(self, features, model_outputs, aux_losses=None) -> tf.Tensor:
"""Standard interface to compute losses.
Args:
features: optional feature/labels tensors.
model_outputs: a nested structure of output tensors.
aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.
Returns:
The total loss tensor.
"""
del model_outputs, features
if aux_losses is None:
losses = [tf.constant(0.0, dtype=tf.float32)]
else:
losses = aux_losses
total_loss = tf.add_n(losses)
return total_loss
def build_metrics(self, training: bool = True):
"""Gets streaming metrics for training/validation."""
del training
return []
def process_metrics(self, metrics, labels, outputs):
"""Process and update metrics. Called when using custom training loop API.
Args:
metrics: a nested structure of metrics objects.
The return of function self.build_metrics.
labels: a tensor or a nested structure of tensors.
outputs: a tensor or a nested structure of tensors.
For example, output of the keras model built by self.build_model.
"""
for metric in metrics:
metric.update_state(labels, outputs)
def process_compiled_metrics(self, compiled_metrics, labels, outputs):
"""Process and update compiled_metrics. call when using compile/fit API.
Args:
compiled_metrics: the compiled metrics (model.compiled_metrics).
labels: a tensor or a nested structure of tensors.
outputs: a tensor or a nested structure of tensors.
For example, output of the keras model built by self.build_model.
"""
compiled_metrics.update_state(labels, outputs)
def train_step(self,
inputs,
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
metrics=None):
"""Does forward and backward.
Args:
inputs: a dictionary of input tensors.
model: the model, forward pass definition.
optimizer: the optimizer for this training step.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
if isinstance(inputs, tuple) and len(inputs) == 2:
features, labels = inputs
else:
features, labels = inputs, inputs
with tf.GradientTape() as tape:
outputs = model(features, training=True)
# Computes per-replica loss.
loss = self.build_losses(
features=labels, model_outputs=outputs, aux_losses=model.losses)
# Scales the loss, as the default gradient all-reduce performs a sum inside
# the optimizer.
scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
tvars = model.trainable_variables
grads = tape.gradient(scaled_loss, tvars)
optimizer.apply_gradients(list(zip(grads, tvars)))
logs = {self.loss: loss}
if metrics:
self.process_metrics(metrics, labels, outputs)
logs.update({m.name: m.result() for m in metrics})
elif model.compiled_metrics:
self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
logs.update({m.name: m.result() for m in model.metrics})
return logs
def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
"""Validatation step.
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
if isinstance(inputs, tuple) and len(inputs) == 2:
features, labels = inputs
else:
features, labels = inputs, inputs
outputs = self.inference_step(features, model)
loss = self.build_losses(
features=labels, model_outputs=outputs, aux_losses=model.losses)
logs = {self.loss: loss}
if metrics:
self.process_metrics(metrics, labels, outputs)
logs.update({m.name: m.result() for m in metrics})
elif model.compiled_metrics:
self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
logs.update({m.name: m.result() for m in model.metrics})
return logs
def inference_step(self, inputs, model: tf.keras.Model):
"""Performs the forward step."""
return model(inputs, training=False)
_REGISTERED_TASK_CLS = {}
# TODO(b/158268740): Move these outside the base class file.
def register_task_cls(task_config: cfg.TaskConfig) -> Task:
"""Register ExperimentConfig factory method."""
return registry.register(_REGISTERED_TASK_CLS, task_config)
def get_task_cls(task_config: cfg.TaskConfig) -> Task:
task_cls = registry.lookup(_REGISTERED_TASK_CLS, task_config)
return task_cls
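For context, a minimal sketch (not part of this commit) of how the registry above is intended to be used: a hypothetical ToyTaskConfig/ToyTask pair is registered with register_task_cls and later resolved with get_task_cls.
import dataclasses

import tensorflow as tf

from official.core import base_task
from official.modeling.hyperparams import config_definitions as cfg


@dataclasses.dataclass
class ToyTaskConfig(cfg.TaskConfig):
  """Hypothetical config, used only for illustration."""
  hidden_size: int = 8


@base_task.register_task_cls(ToyTaskConfig)
class ToyTask(base_task.Task):
  """A toy task that builds a one-layer Keras model."""

  def build_model(self) -> tf.keras.Model:
    return tf.keras.Sequential(
        [tf.keras.layers.Dense(self.task_config.hidden_size)])


# Resolve the task class from its config class and build the model.
task = base_task.get_task_cls(ToyTaskConfig)(ToyTaskConfig())
model = task.build_model()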
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A common dataset reader."""
from typing import Any, Callable, List, Optional
import tensorflow as tf
from official.modeling.hyperparams import config_definitions as cfg
class InputReader:
"""Input reader that returns a tf.data.Dataset instance."""
def __init__(self,
params: cfg.DataConfig,
shards: Optional[List[str]] = None,
dataset_fn=tf.data.TFRecordDataset,
decoder_fn: Optional[Callable[..., Any]] = None,
parser_fn: Optional[Callable[..., Any]] = None,
dataset_transform_fn: Optional[Callable[[tf.data.Dataset],
tf.data.Dataset]] = None,
postprocess_fn: Optional[Callable[..., Any]] = None):
"""Initializes an InputReader instance.
Args:
params: A config_definitions.DataConfig object.
shards: A list of files to be read. If given, read from these files.
Otherwise, read from params.input_path.
dataset_fn: A callable that builds a `tf.data.Dataset` from the input
files, for example `tf.data.TFRecordDataset`.
decoder_fn: An optional `callable` that takes a serialized data string
and decodes it into a dictionary of raw tensors.
parser_fn: An optional `callable` that takes the decoded raw tensors dict
and parses them into a dictionary of tensors that can be consumed by the
model. It will be executed after decoder_fn.
dataset_transform_fn: An optional `callable` that takes a
`tf.data.Dataset` object and returns a `tf.data.Dataset`. It will be
executed after parser_fn.
postprocess_fn: An optional `callable` that processes batched tensors. It
will be executed after batching.
"""
# TODO(chendouble): Support TFDS as input_path.
self._shards = shards
if self._shards:
self._num_files = len(self._shards)
else:
self._input_patterns = params.input_path.strip().split(',')
self._num_files = 0
for input_pattern in self._input_patterns:
input_pattern = input_pattern.strip()
if not input_pattern:
continue
matched_files = tf.io.gfile.glob(input_pattern)
if not matched_files:
raise ValueError('%s does not match any files.' % input_pattern)
else:
self._num_files += len(matched_files)
if self._num_files == 0:
raise ValueError('%s does not match any files.' % params.input_path)
self._global_batch_size = params.global_batch_size
self._is_training = params.is_training
self._drop_remainder = params.drop_remainder
self._shuffle_buffer_size = params.shuffle_buffer_size
self._cache = params.cache
self._cycle_length = params.cycle_length
self._sharding = params.sharding
self._examples_consume = params.examples_consume
self._dataset_fn = dataset_fn
self._decoder_fn = decoder_fn
self._parser_fn = parser_fn
self._dataset_transform_fn = dataset_transform_fn
self._postprocess_fn = postprocess_fn
def _read_sharded_files(
self,
input_context: Optional[tf.distribute.InputContext] = None):
"""Reads a dataset from sharded files."""
# Read from `self._shards` if it is provided.
if self._shards:
dataset = tf.data.Dataset.from_tensor_slices(self._shards)
else:
dataset = tf.data.Dataset.list_files(
self._input_patterns, shuffle=self._is_training)
if self._sharding and input_context and (
input_context.num_input_pipelines > 1):
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
if self._is_training:
dataset = dataset.repeat()
dataset = dataset.interleave(
map_func=self._dataset_fn,
cycle_length=self._cycle_length,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset
def _read_single_file(
self,
input_context: Optional[tf.distribute.InputContext] = None):
"""Reads a dataset from a single file."""
# Read from `self._shards` if it is provided.
dataset = self._dataset_fn(self._shards or self._input_patterns)
# When reading from a single input file, disable auto sharding so that the
# same input file is sent to all workers.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
tf.data.experimental.AutoShardPolicy.OFF)
dataset = dataset.with_options(options)
if self._sharding and input_context and (
input_context.num_input_pipelines > 1):
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
if self._is_training:
dataset = dataset.repeat()
return dataset
def read(
self,
input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
"""Generates a tf.data.Dataset object."""
if self._num_files > 1:
dataset = self._read_sharded_files(input_context)
else:
assert self._num_files == 1
dataset = self._read_single_file(input_context)
if self._cache:
dataset = dataset.cache()
if self._is_training:
dataset = dataset.shuffle(self._shuffle_buffer_size)
if self._examples_consume > 0:
dataset = dataset.take(self._examples_consume)
def maybe_map_fn(dataset, fn):
return dataset if fn is None else dataset.map(
fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = maybe_map_fn(dataset, self._decoder_fn)
dataset = maybe_map_fn(dataset, self._parser_fn)
if self._dataset_transform_fn is not None:
dataset = self._dataset_transform_fn(dataset)
per_replica_batch_size = input_context.get_per_replica_batch_size(
self._global_batch_size) if input_context else self._global_batch_size
dataset = dataset.batch(
per_replica_batch_size, drop_remainder=self._drop_remainder)
dataset = maybe_map_fn(dataset, self._postprocess_fn)
return dataset.prefetch(tf.data.experimental.AUTOTUNE)
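A hedged usage sketch of the reader above, assuming cfg.DataConfig exposes the fields InputReader reads (input_path, global_batch_size, is_training, and so on); the file pattern and feature spec are hypothetical.
import tensorflow as tf

from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg


def decode_fn(record):
  # Decode one serialized tf.Example into a dict of raw tensors.
  return tf.io.parse_single_example(
      record, {'x': tf.io.FixedLenFeature([1], tf.float32)})


params = cfg.DataConfig(
    input_path='/tmp/data/train-*.tfrecord',  # hypothetical file pattern
    global_batch_size=32,
    is_training=True)
dataset = input_reader.InputReader(params, decoder_fn=decode_fn).read()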
@@ -21,6 +21,7 @@ import tensorflow as tf
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import encoders
from official.nlp.modeling import layers
from official.nlp.modeling import networks
@@ -77,3 +78,23 @@ def instantiate_from_cfg(
stddev=encoder_cfg.initializer_range),
encoder_network=encoder_network,
classification_heads=classification_heads)
@dataclasses.dataclass
class BertPretrainDataConfig(cfg.DataConfig):
"""Data config for BERT pretraining task."""
input_path: str = ""
global_batch_size: int = 512
is_training: bool = True
seq_length: int = 512
max_predictions_per_seq: int = 76
use_next_sentence_label: bool = True
use_position_id: bool = False
@dataclasses.dataclass
class BertPretrainEvalDataConfig(BertPretrainDataConfig):
"""Data config for the eval set in BERT pretraining task."""
input_path: str = ""
global_batch_size: int = 512
is_training: bool = False
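A small sketch (assumed usage, hypothetical paths) of how the eval config inherits the training config's fields while overriding only the training-specific defaults.
from official.nlp.configs import bert

train_data = bert.BertPretrainDataConfig(
    input_path='/tmp/bert/train-*.tfrecord')  # hypothetical path
eval_data = bert.BertPretrainEvalDataConfig(
    input_path='/tmp/bert/eval-*.tfrecord')   # hypothetical path
# seq_length, max_predictions_per_seq, etc. are inherited unchanged; only the
# overridden defaults differ between the two configs.
assert train_data.is_training and not eval_data.is_training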
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loads dataset for the BERT pretraining task."""
from typing import Mapping, Optional
import tensorflow as tf
from official.core import input_reader
class BertPretrainDataLoader:
"""A class to load dataset for bert pretraining task."""
def __init__(self, params):
"""Inits `BertPretrainDataLoader` class.
Args:
params: A `BertPretrainDataConfig` object.
"""
self._params = params
self._seq_length = params.seq_length
self._max_predictions_per_seq = params.max_predictions_per_seq
self._use_next_sentence_label = params.use_next_sentence_label
self._use_position_id = params.use_position_id
def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example."""
name_to_features = {
'input_ids':
tf.io.FixedLenFeature([self._seq_length], tf.int64),
'input_mask':
tf.io.FixedLenFeature([self._seq_length], tf.int64),
'segment_ids':
tf.io.FixedLenFeature([self._seq_length], tf.int64),
'masked_lm_positions':
tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.int64),
'masked_lm_ids':
tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.int64),
'masked_lm_weights':
tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.float32),
}
if self._use_next_sentence_label:
name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
tf.int64)
if self._use_position_id:
name_to_features['position_ids'] = tf.io.FixedLenFeature(
[self._seq_length], tf.int64)
example = tf.io.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
for name in list(example.keys()):
t = example[name]
if t.dtype == tf.int64:
t = tf.cast(t, tf.int32)
example[name] = t
return example
def _parse(self, record: Mapping[str, tf.Tensor]):
"""Parses raw tensors into a dict of tensors to be consumed by the model."""
x = {
'input_word_ids': record['input_ids'],
'input_mask': record['input_mask'],
'input_type_ids': record['segment_ids'],
'masked_lm_positions': record['masked_lm_positions'],
'masked_lm_ids': record['masked_lm_ids'],
'masked_lm_weights': record['masked_lm_weights'],
}
if self._use_next_sentence_label:
x['next_sentence_labels'] = record['next_sentence_labels']
if self._use_position_id:
x['position_ids'] = record['position_ids']
return x
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a tf.dataset.Dataset."""
reader = input_reader.InputReader(
params=self._params,
decoder_fn=self._decode,
parser_fn=self._parse)
return reader.read(input_context)
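A hedged end-to-end sketch of the loader above, combining it with the BertPretrainDataConfig added earlier in this commit; the input path is hypothetical and the records are assumed to come from the standard BERT pretraining data tooling.
from official.nlp.configs import bert
from official.nlp.data import pretrain_dataloader

data_config = bert.BertPretrainDataConfig(
    input_path='/tmp/bert/pretrain-*.tfrecord',  # hypothetical path
    seq_length=128,
    max_predictions_per_seq=20,
    global_batch_size=32)
dataset = pretrain_dataloader.BertPretrainDataLoader(data_config).load()
# Each batch is a dict with 'input_word_ids', 'input_mask', 'input_type_ids',
# 'masked_lm_positions', 'masked_lm_ids', 'masked_lm_weights', and (by
# default) 'next_sentence_labels'.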
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Masked language task."""
import dataclasses
import tensorflow as tf
from official.core import base_task
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import bert
from official.nlp.data import pretrain_dataloader
from official.nlp.modeling import losses as loss_lib
@dataclasses.dataclass
class MaskedLMConfig(cfg.TaskConfig):
"""The model config."""
network: bert.BertPretrainerConfig = bert.BertPretrainerConfig(cls_heads=[
bert.ClsHeadConfig(
inner_dim=768, num_classes=2, dropout_rate=0.1, name='next_sentence')
])
train_data: cfg.DataConfig = cfg.DataConfig()
validation_data: cfg.DataConfig = cfg.DataConfig()
class MaskedLMAccuracy(tf.keras.metrics.Mean):
"""The weighted accuracy metric for the masked language model."""
def __init__(self, name=None, dtype=None):
super(MaskedLMAccuracy, self).__init__(name=name, dtype=dtype)
def update_state(self, y_true, y_pred, sample_weight=None):
masked_lm_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
y_true, y_pred)
numerator = tf.reduce_sum(masked_lm_accuracy * sample_weight)
denominator = tf.reduce_sum(sample_weight) + 1e-5
masked_lm_accuracy = numerator / denominator
return super(MaskedLMAccuracy, self).update_state(masked_lm_accuracy)
@base_task.register_task_cls(MaskedLMConfig)
class MaskedLMTask(base_task.Task):
"""Mock task object for testing."""
def build_model(self):
return bert.instantiate_from_cfg(self.task_config.network)
def build_losses(self,
features,
model_outputs,
metrics,
aux_losses=None) -> tf.Tensor:
metrics = {metric.name: metric for metric in metrics}
lm_output = tf.nn.log_softmax(model_outputs['lm_output'], axis=-1)
mlm_loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
labels=features['masked_lm_ids'],
predictions=lm_output,
weights=features['masked_lm_weights'])
metrics['lm_example_loss'].update_state(mlm_loss)
if 'next_sentence_labels' in features:
sentence_labels = features['next_sentence_labels']
sentence_loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
labels=sentence_labels,
predictions=tf.nn.log_softmax(
model_outputs['next_sentence'], axis=-1))
metrics['next_sentence_loss'].update_state(sentence_loss)
total_loss = mlm_loss + sentence_loss
else:
total_loss = mlm_loss
if aux_losses:
total_loss += tf.add_n(aux_losses)
return total_loss
def build_inputs(self, params, input_context=None):
"""Returns tf.data.Dataset for pretraining."""
if params.input_path == 'dummy':
def dummy_data(_):
dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
dummy_lm = tf.zeros((1, params.max_predictions_per_seq), dtype=tf.int32)
return dict(
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids,
masked_lm_positions=dummy_lm,
masked_lm_ids=dummy_lm,
masked_lm_weights=tf.cast(dummy_lm, dtype=tf.float32),
next_sentence_labels=tf.zeros((1, 1), dtype=tf.int32))
dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat()
dataset = dataset.map(
dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset
return pretrain_dataloader.BertPretrainDataLoader(params).load(
input_context)
def build_metrics(self, training=None):
del training
metrics = [
MaskedLMAccuracy(name='masked_lm_accuracy'),
tf.keras.metrics.Mean(name='lm_example_loss')
]
# TODO(hongkuny): rethink how to manage metrics creation with heads.
if self.task_config.train_data.use_next_sentence_label:
metrics.append(
tf.keras.metrics.SparseCategoricalAccuracy(
name='next_sentence_accuracy'))
metrics.append(tf.keras.metrics.Mean(name='next_sentence_loss'))
return metrics
def process_metrics(self, metrics, inputs, outputs):
metrics = {metric.name: metric for metric in metrics}
if 'masked_lm_accuracy' in metrics:
metrics['masked_lm_accuracy'].update_state(inputs['masked_lm_ids'],
outputs['lm_output'],
inputs['masked_lm_weights'])
if 'next_sentence_accuracy' in metrics:
metrics['next_sentence_accuracy'].update_state(
inputs['next_sentence_labels'], outputs['next_sentence'])
def train_step(self, inputs, model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer, metrics):
"""Does forward and backward.
Args:
inputs: a dictionary of input tensors.
model: the model, forward pass definition.
optimizer: the optimizer for this training step.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
with tf.GradientTape() as tape:
outputs = model(inputs, training=True)
# Computes per-replica loss.
loss = self.build_losses(
features=inputs,
model_outputs=outputs,
metrics=metrics,
aux_losses=model.losses)
# Scales the loss, as the default gradient all-reduce performs a sum inside
# the optimizer.
# TODO(b/154564893): enable loss scaling.
# scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
tvars = model.trainable_variables
grads = tape.gradient(loss, tvars)
optimizer.apply_gradients(list(zip(grads, tvars)))
self.process_metrics(metrics, inputs, outputs)
return {self.loss: loss}
def validation_step(self, inputs, model: tf.keras.Model, metrics):
"""Validatation step.
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
outputs = self.inference_step(inputs, model)
loss = self.build_losses(
features=inputs,
model_outputs=outputs,
metrics=metrics,
aux_losses=model.losses)
self.process_metrics(metrics, inputs, outputs)
return {self.loss: loss}
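A short sketch (assumed usage) of how the task registered above would be resolved from its config class via the registry in official/core/base_task.py, rather than constructed directly.
from official.core import base_task
from official.nlp.tasks import masked_lm

config = masked_lm.MaskedLMConfig()
# The register_task_cls decorator keyed MaskedLMTask by MaskedLMConfig, so the
# class can be looked up from the config type alone.
task = base_task.get_task_cls(masked_lm.MaskedLMConfig)(config)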
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.tasks.masked_lm."""
import tensorflow as tf
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.tasks import masked_lm
class MLMTaskTest(tf.test.TestCase):
def test_task(self):
config = masked_lm.MaskedLMConfig(
network=bert.BertPretrainerConfig(
encoders.TransformerEncoderConfig(vocab_size=30522, num_layers=1),
num_masked_tokens=20,
cls_heads=[
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
]),
train_data=bert.BertPretrainDataConfig(
input_path="dummy",
max_predictions_per_seq=20,
seq_length=128,
global_batch_size=1))
task = masked_lm.MaskedLMTask(config)
model = task.build_model()
metrics = task.build_metrics()
dataset = task.build_inputs(config.train_data)
iterator = iter(dataset)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
task.train_step(next(iterator), model, optimizer, metrics=metrics)
task.validation_step(next(iterator), model, metrics=metrics)
if __name__ == "__main__":
tf.test.main()