"vscode:/vscode.git/clone" did not exist on "46f16a5e0b065865d9627ddbc179dc2ef2e7d802"
Commit 5cc6df63 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Open source dual encoder tasks and dataloaders.

PiperOrigin-RevId: 408397786
parent e97979cb
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Loads dataset for the dual encoder (retrieval) task."""
import functools
import itertools
from typing import Iterable, Mapping, Optional, Tuple
import dataclasses
import tensorflow as tf
import tensorflow_hub as hub
from official.core import config_definitions as cfg
from official.core import input_reader
from official.nlp.data import data_loader
from official.nlp.data import data_loader_factory
from official.nlp.modeling import layers
@dataclasses.dataclass
class DualEncoderDataConfig(cfg.DataConfig):
"""Data config for dual encoder task (tasks/dual_encoder)."""
# Either set `input_path`...
input_path: str = ''
# ...or `tfds_name` and `tfds_split` to specify input.
tfds_name: str = ''
tfds_split: str = ''
global_batch_size: int = 32
# Either build preprocessing with Python code by specifying these values...
vocab_file: str = ''
lower_case: bool = True
# ...or load preprocessing from a SavedModel at this location.
preprocessing_hub_module_url: str = ''
left_text_fields: Tuple[str] = ('left_input',)
right_text_fields: Tuple[str] = ('right_input',)
is_training: bool = True
seq_length: int = 128
@data_loader_factory.register_data_loader_cls(DualEncoderDataConfig)
class DualEncoderDataLoader(data_loader.DataLoader):
"""A class to load dataset for dual encoder task (tasks/dual_encoder)."""
def __init__(self, params):
if bool(params.tfds_name) == bool(params.input_path):
raise ValueError('Must specify either `tfds_name` and `tfds_split` '
'or `input_path`.')
if bool(params.vocab_file) == bool(params.preprocessing_hub_module_url):
raise ValueError('Must specify exactly one of vocab_file (with matching '
'lower_case flag) or preprocessing_hub_module_url.')
self._params = params
self._seq_length = params.seq_length
self._left_text_fields = params.left_text_fields
self._right_text_fields = params.right_text_fields
if params.preprocessing_hub_module_url:
preprocessing_hub_module = hub.load(params.preprocessing_hub_module_url)
self._tokenizer = preprocessing_hub_module.tokenize
self._pack_inputs = functools.partial(
preprocessing_hub_module.bert_pack_inputs,
seq_length=params.seq_length)
else:
self._tokenizer = layers.BertTokenizer(
vocab_file=params.vocab_file, lower_case=params.lower_case)
self._pack_inputs = layers.BertPackInputs(
seq_length=params.seq_length,
special_tokens_dict=self._tokenizer.get_special_tokens_dict())
def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example."""
name_to_features = {
x: tf.io.FixedLenFeature([], tf.string)
for x in itertools.chain(
*[self._left_text_fields, self._right_text_fields])
}
example = tf.io.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
for name in example:
t = example[name]
if t.dtype == tf.int64:
t = tf.cast(t, tf.int32)
example[name] = t
return example
def _bert_tokenize(
self, record: Mapping[str, tf.Tensor],
text_fields: Iterable[str]) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
"""Tokenize the input in text_fields using BERT tokenizer.
Args:
record: A tfexample record contains the features.
text_fields: A list of fields to be tokenzied.
Returns:
The tokenized features in a tuple of (input_word_ids, input_mask,
input_type_ids).
"""
segments_text = [record[x] for x in text_fields]
segments_tokens = [self._tokenizer(s) for s in segments_text]
segments = [tf.cast(x.merge_dims(1, 2), tf.int32) for x in segments_tokens]
return self._pack_inputs(segments)
def _bert_preprocess(
self, record: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
"""Perform the bert word piece tokenization for left and right inputs."""
def _switch_prefix(string, old, new):
if string.startswith(old): return new + string[len(old):]
raise ValueError('Expected {} to start with {}'.format(string, old))
def _switch_key_prefix(d, old, new):
return {_switch_prefix(key, old, new): value for key, value in d.items()}
model_inputs = _switch_key_prefix(
self._bert_tokenize(record, self._left_text_fields),
'input_', 'left_')
model_inputs.update(_switch_key_prefix(
self._bert_tokenize(record, self._right_text_fields),
'input_', 'right_'))
return model_inputs
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a tf.dataset.Dataset."""
reader = input_reader.InputReader(
params=self._params,
# Skip `decoder_fn` for tfds input.
decoder_fn=self._decode if self._params.input_path else None,
dataset_fn=tf.data.TFRecordDataset,
postprocess_fn=self._bert_preprocess)
return reader.read(input_context)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.data.dual_encoder_dataloader."""
import os
from absl.testing import parameterized
import tensorflow as tf
from official.nlp.data import dual_encoder_dataloader
_LEFT_FEATURE_NAME = 'left_input'
_RIGHT_FEATURE_NAME = 'right_input'
def _create_fake_dataset(output_path):
"""Creates a fake dataset contains examples for training a dual encoder model.
The created dataset contains examples with two byteslist features keyed by
_LEFT_FEATURE_NAME and _RIGHT_FEATURE_NAME.
Args:
output_path: The output path of the fake dataset.
"""
def create_str_feature(values):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))
with tf.io.TFRecordWriter(output_path) as writer:
for _ in range(100):
features = {}
features[_LEFT_FEATURE_NAME] = create_str_feature([b'hello world.'])
features[_RIGHT_FEATURE_NAME] = create_str_feature([b'world hello.'])
tf_example = tf.train.Example(
features=tf.train.Features(feature=features))
writer.write(tf_example.SerializeToString())
def _make_vocab_file(vocab, output_path):
with tf.io.gfile.GFile(output_path, 'w') as f:
f.write('\n'.join(vocab + ['']))
class DualEncoderDataTest(tf.test.TestCase, parameterized.TestCase):
def test_load_dataset(self):
seq_length = 16
batch_size = 10
train_data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
vocab_path = os.path.join(self.get_temp_dir(), 'vocab.txt')
_create_fake_dataset(train_data_path)
_make_vocab_file(
['[PAD]', '[UNK]', '[CLS]', '[SEP]', 'he', '#llo', 'world'], vocab_path)
data_config = dual_encoder_dataloader.DualEncoderDataConfig(
input_path=train_data_path,
seq_length=seq_length,
vocab_file=vocab_path,
lower_case=True,
left_text_fields=(_LEFT_FEATURE_NAME,),
right_text_fields=(_RIGHT_FEATURE_NAME,),
global_batch_size=batch_size)
dataset = dual_encoder_dataloader.DualEncoderDataLoader(
data_config).load()
features = next(iter(dataset))
self.assertCountEqual(
['left_word_ids', 'left_mask', 'left_type_ids', 'right_word_ids',
'right_mask', 'right_type_ids'],
features.keys())
self.assertEqual(features['left_word_ids'].shape, (batch_size, seq_length))
self.assertEqual(features['left_mask'].shape, (batch_size, seq_length))
self.assertEqual(features['left_type_ids'].shape, (batch_size, seq_length))
self.assertEqual(features['right_word_ids'].shape, (batch_size, seq_length))
self.assertEqual(features['right_mask'].shape, (batch_size, seq_length))
self.assertEqual(features['right_type_ids'].shape, (batch_size, seq_length))
@parameterized.parameters(False, True)
def test_load_tfds(self, use_preprocessing_hub):
seq_length = 16
batch_size = 10
if use_preprocessing_hub:
vocab_path = ''
preprocessing_hub = (
'https://tfhub.dev/tensorflow/bert_multi_cased_preprocess/3')
else:
vocab_path = os.path.join(self.get_temp_dir(), 'vocab.txt')
_make_vocab_file(
['[PAD]', '[UNK]', '[CLS]', '[SEP]', 'he', '#llo', 'world'],
vocab_path)
preprocessing_hub = ''
data_config = dual_encoder_dataloader.DualEncoderDataConfig(
tfds_name='para_crawl/enmt',
tfds_split='train',
seq_length=seq_length,
vocab_file=vocab_path,
lower_case=True,
left_text_fields=('en',),
right_text_fields=('mt',),
preprocessing_hub_module_url=preprocessing_hub,
global_batch_size=batch_size)
dataset = dual_encoder_dataloader.DualEncoderDataLoader(
data_config).load()
features = next(iter(dataset))
self.assertCountEqual(
['left_word_ids', 'left_mask', 'left_type_ids', 'right_word_ids',
'right_mask', 'right_type_ids'],
features.keys())
self.assertEqual(features['left_word_ids'].shape, (batch_size, seq_length))
self.assertEqual(features['left_mask'].shape, (batch_size, seq_length))
self.assertEqual(features['left_type_ids'].shape, (batch_size, seq_length))
self.assertEqual(features['right_word_ids'].shape, (batch_size, seq_length))
self.assertEqual(features['right_mask'].shape, (batch_size, seq_length))
self.assertEqual(features['right_type_ids'].shape, (batch_size, seq_length))
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dual encoder (retrieval) task."""
from typing import Mapping, Tuple
# Import libraries
from absl import logging
import dataclasses
import tensorflow as tf
from official.core import base_task
from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import encoders
from official.nlp.data import data_loader_factory
from official.nlp.modeling import models
from official.nlp.tasks import utils
@dataclasses.dataclass
class ModelConfig(base_config.Config):
"""A dual encoder (retrieval) configuration."""
# Normalize input embeddings if set to True.
normalize: bool = True
# Maximum input sequence length.
max_sequence_length: int = 64
# Parameters for training a dual encoder model with additive margin, see
# https://www.ijcai.org/Proceedings/2019/0746.pdf for more details.
logit_scale: float = 1
logit_margin: float = 0
bidirectional: bool = False
# Defining k for calculating metrics recall@k.
eval_top_k: Tuple[int, ...] = (1, 3, 10)
encoder: encoders.EncoderConfig = (
encoders.EncoderConfig())
@dataclasses.dataclass
class DualEncoderConfig(cfg.TaskConfig):
"""The model config."""
# At most one of `init_checkpoint` and `hub_module_url` can
# be specified.
init_checkpoint: str = ''
hub_module_url: str = ''
# Defines the concrete model config at instantiation time.
model: ModelConfig = ModelConfig()
train_data: cfg.DataConfig = cfg.DataConfig()
validation_data: cfg.DataConfig = cfg.DataConfig()
@task_factory.register_task_cls(DualEncoderConfig)
class DualEncoderTask(base_task.Task):
"""Task object for dual encoder."""
def build_model(self):
"""Interface to build model. Refer to base_task.Task.build_model."""
if self.task_config.hub_module_url and self.task_config.init_checkpoint:
raise ValueError('At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.')
if self.task_config.hub_module_url:
encoder_network = utils.get_encoder_from_hub(
self.task_config.hub_module_url)
else:
encoder_network = encoders.build_encoder(self.task_config.model.encoder)
# Currently, we only supports bert-style dual encoder.
return models.DualEncoder(
network=encoder_network,
max_seq_length=self.task_config.model.max_sequence_length,
normalize=self.task_config.model.normalize,
logit_scale=self.task_config.model.logit_scale,
logit_margin=self.task_config.model.logit_margin,
output='logits')
def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
"""Interface to compute losses. Refer to base_task.Task.build_losses."""
del labels
left_logits = model_outputs['left_logits']
right_logits = model_outputs['right_logits']
batch_size = tf_utils.get_shape_list(left_logits, name='batch_size')[0]
ranking_labels = tf.range(batch_size)
loss = tf_utils.safe_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=ranking_labels,
logits=left_logits))
if self.task_config.model.bidirectional:
right_rank_loss = tf_utils.safe_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=ranking_labels,
logits=right_logits))
loss += right_rank_loss
return tf.reduce_mean(loss)
def build_inputs(self, params, input_context=None) -> tf.data.Dataset:
"""Returns tf.data.Dataset for sentence_prediction task."""
if params.input_path != 'dummy':
return data_loader_factory.get_data_loader(params).load(input_context)
def dummy_data(_):
dummy_ids = tf.zeros((10, params.seq_length), dtype=tf.int32)
x = dict(
left_word_ids=dummy_ids,
left_mask=dummy_ids,
left_type_ids=dummy_ids,
right_word_ids=dummy_ids,
right_mask=dummy_ids,
right_type_ids=dummy_ids)
return x
dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat()
dataset = dataset.map(
dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset
def build_metrics(self, training=None):
del training
metrics = [tf.keras.metrics.Mean(name='batch_size_per_core')]
for k in self.task_config.model.eval_top_k:
metrics.append(tf.keras.metrics.SparseTopKCategoricalAccuracy(
k=k, name=f'left_recall_at_{k}'))
if self.task_config.model.bidirectional:
metrics.append(tf.keras.metrics.SparseTopKCategoricalAccuracy(
k=k, name=f'right_recall_at_{k}'))
return metrics
def process_metrics(self, metrics, labels, model_outputs):
del labels
metrics = dict([(metric.name, metric) for metric in metrics])
left_logits = model_outputs['left_logits']
right_logits = model_outputs['right_logits']
batch_size = tf_utils.get_shape_list(
left_logits, name='sequence_output_tensor')[0]
ranking_labels = tf.range(batch_size)
for k in self.task_config.model.eval_top_k:
metrics[f'left_recall_at_{k}'].update_state(ranking_labels, left_logits)
if self.task_config.model.bidirectional:
metrics[f'right_recall_at_{k}'].update_state(ranking_labels,
right_logits)
metrics['batch_size_per_core'].update_state(batch_size)
def validation_step(self,
inputs,
model: tf.keras.Model,
metrics=None) -> Mapping[str, tf.Tensor]:
outputs = model(inputs)
loss = self.build_losses(
labels=None, model_outputs=outputs, aux_losses=model.losses)
logs = {self.loss: loss}
if metrics:
self.process_metrics(metrics, None, outputs)
logs.update({m.name: m.result() for m in metrics})
elif model.compiled_metrics:
self.process_compiled_metrics(model.compiled_metrics, None, outputs)
logs.update({m.name: m.result() for m in model.metrics})
return logs
def initialize(self, model):
"""Load a pretrained checkpoint (if exists) and then train from iter 0."""
ckpt_dir_or_file = self.task_config.init_checkpoint
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
if not ckpt_dir_or_file:
return
pretrain2finetune_mapping = {
'encoder': model.checkpoint_items['encoder'],
}
ckpt = tf.train.Checkpoint(**pretrain2finetune_mapping)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.tasks.sentence_prediction."""
import functools
import os
from absl.testing import parameterized
import tensorflow as tf
from official.nlp.bert import configs
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.data import dual_encoder_dataloader
from official.nlp.tasks import dual_encoder
from official.nlp.tasks import masked_lm
from official.nlp.tools import export_tfhub_lib
class DualEncoderTaskTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(DualEncoderTaskTest, self).setUp()
self._train_data_config = (
dual_encoder_dataloader.DualEncoderDataConfig(
input_path="dummy", seq_length=32))
def get_model_config(self):
return dual_encoder.ModelConfig(
max_sequence_length=32,
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)))
def _run_task(self, config):
task = dual_encoder.DualEncoderTask(config)
model = task.build_model()
metrics = task.build_metrics()
strategy = tf.distribute.get_strategy()
dataset = strategy.distribute_datasets_from_function(
functools.partial(task.build_inputs, config.train_data))
dataset.batch(10)
iterator = iter(dataset)
optimizer = tf.keras.optimizers.SGD(lr=0.1)
task.train_step(next(iterator), model, optimizer, metrics=metrics)
task.validation_step(next(iterator), model, metrics=metrics)
model.save(os.path.join(self.get_temp_dir(), "saved_model"))
def test_task(self):
config = dual_encoder.DualEncoderConfig(
init_checkpoint=self.get_temp_dir(),
model=self.get_model_config(),
train_data=self._train_data_config)
task = dual_encoder.DualEncoderTask(config)
model = task.build_model()
metrics = task.build_metrics()
dataset = task.build_inputs(config.train_data)
iterator = iter(dataset)
optimizer = tf.keras.optimizers.SGD(lr=0.1)
task.train_step(next(iterator), model, optimizer, metrics=metrics)
task.validation_step(next(iterator), model, metrics=metrics)
# Saves a checkpoint.
pretrain_cfg = bert.PretrainerConfig(
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)))
pretrain_model = masked_lm.MaskedLMTask(None).build_model(pretrain_cfg)
ckpt = tf.train.Checkpoint(
model=pretrain_model, **pretrain_model.checkpoint_items)
ckpt.save(config.init_checkpoint)
task.initialize(model)
def _export_bert_tfhub(self):
bert_config = configs.BertConfig(
vocab_size=30522,
hidden_size=16,
intermediate_size=32,
max_position_embeddings=128,
num_attention_heads=2,
num_hidden_layers=4)
encoder = export_tfhub_lib.get_bert_encoder(bert_config)
model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
checkpoint = tf.train.Checkpoint(encoder=encoder)
checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
vocab_file = os.path.join(self.get_temp_dir(), "uncased_vocab.txt")
with tf.io.gfile.GFile(vocab_file, "w") as f:
f.write("dummy content")
export_path = os.path.join(self.get_temp_dir(), "hub")
export_tfhub_lib.export_model(
export_path,
bert_config=bert_config,
encoder_config=None,
model_checkpoint_path=model_checkpoint_path,
vocab_file=vocab_file,
do_lower_case=True,
with_mlm=False)
return export_path
def test_task_with_hub(self):
hub_module_url = self._export_bert_tfhub()
config = dual_encoder.DualEncoderConfig(
hub_module_url=hub_module_url,
model=self.get_model_config(),
train_data=self._train_data_config)
self._run_task(config)
if __name__ == "__main__":
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment