Commit c31ae6da authored by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 327257703
parent a54f7d00
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dual encoder (retrieval) task."""
from typing import Mapping, Tuple
# Import libraries
from absl import logging
import dataclasses
import tensorflow as tf
import tensorflow_hub as hub
from official.core import base_task
from official.core import task_factory
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import encoders
from official.nlp.data import data_loader_factory
from official.nlp.modeling import models
from official.nlp.tasks import utils
@dataclasses.dataclass
class ModelConfig(base_config.Config):
  """A dual encoder (retrieval) configuration.

  The fields below are read by `DualEncoderTask` when it builds the model,
  the losses and the metrics.
  """
  # Normalize input embeddings if set to True.
  normalize: bool = True
  # Maximum input sequence length.
  max_sequence_length: int = 64
  # Parameters for training a dual encoder model with additive margin, see
  # https://www.ijcai.org/Proceedings/2019/0746.pdf for more details.
  logit_scale: float = 1
  logit_margin: float = 0
  # If True, the task also adds a ranking loss and recall metrics computed
  # from the right-side logits.
  bidirectional: bool = False
  # Defining k for calculating metrics recall@k.
  eval_top_k: Tuple[int, ...] = (1, 3, 10)
  # Encoder configuration, used to build the encoder network when no TF-Hub
  # module is supplied. A `default_factory` is used instead of a shared
  # `EncoderConfig()` instance: dataclass instances are unhashable, and
  # Python 3.11+ rejects unhashable objects as dataclass field defaults.
  encoder: encoders.EncoderConfig = dataclasses.field(
      default_factory=encoders.EncoderConfig)
@dataclasses.dataclass
class DualEncoderConfig(cfg.TaskConfig):
  """The dual encoder task config."""
  # At most one of `init_checkpoint` and `hub_module_url` can
  # be specified.
  init_checkpoint: str = ''
  hub_module_url: str = ''
  # Defines the concrete model config at instantiation time.
  # `default_factory` gives every task config its own sub-config objects;
  # sharing one default instance is also rejected by Python 3.11+ dataclasses
  # because dataclass instances are unhashable.
  model: ModelConfig = dataclasses.field(default_factory=ModelConfig)
  train_data: cfg.DataConfig = dataclasses.field(
      default_factory=cfg.DataConfig)
  validation_data: cfg.DataConfig = dataclasses.field(
      default_factory=cfg.DataConfig)
@task_factory.register_task_cls(DualEncoderConfig)
class DualEncoderTask(base_task.Task):
  """Task object for dual encoder."""

  def __init__(self, params=cfg.TaskConfig, logging_dir=None):
    """Initializes the task and optionally loads a TF-Hub encoder module.

    Args:
      params: The task config. It must expose `hub_module_url` and
        `init_checkpoint`.
        NOTE(review): the default is the `cfg.TaskConfig` class itself, not an
        instance; callers appear to be expected to always pass a config.
      logging_dir: Forwarded unchanged to the base task constructor.

    Raises:
      ValueError: If both `hub_module_url` and `init_checkpoint` are set.
    """
    super().__init__(params, logging_dir)
    if params.hub_module_url and params.init_checkpoint:
      raise ValueError('At most one of `hub_module_url` and '
                       '`init_checkpoint` can be specified.')
    if params.hub_module_url:
      self._hub_module = hub.load(params.hub_module_url)
    else:
      self._hub_module = None

  def build_model(self):
    """Interface to build model. Refer to base_task.Task.build_model."""
    model_config = self.task_config.model
    if self._hub_module is not None:
      encoder_network = utils.get_encoder_from_hub(self._hub_module)
    else:
      encoder_network = encoders.build_encoder(model_config.encoder)

    # Currently, we only support bert-style dual encoders.
    return models.DualEncoder(
        network=encoder_network,
        max_seq_length=model_config.max_sequence_length,
        normalize=model_config.normalize,
        logit_scale=model_config.logit_scale,
        logit_margin=model_config.logit_margin,
        output='logits')

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    """Interface to compute losses. Refer to base_task.Task.build_losses.

    The target for row `i` of the logits is index `i`, so `labels` is not
    needed and is discarded.

    Args:
      labels: Unused.
      model_outputs: A `(left_logits, right_logits)` pair.
      aux_losses: Accepted for interface compatibility; currently unused.

    Returns:
      The mean sparse softmax cross-entropy over `left_logits`, plus the same
      quantity over `right_logits` when the model config is `bidirectional`.
    """
    del labels

    left_logits, right_logits = model_outputs
    batch_size = tf_utils.get_shape_list(left_logits, name='batch_size')[0]
    ranking_labels = tf.range(batch_size)

    loss = tf_utils.safe_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=ranking_labels,
            logits=left_logits))

    if self.task_config.model.bidirectional:
      right_rank_loss = tf_utils.safe_mean(
          tf.nn.sparse_softmax_cross_entropy_with_logits(
              labels=ranking_labels,
              logits=right_logits))
      loss += right_rank_loss
    return tf.reduce_mean(loss)

  def build_inputs(self, params, input_context=None) -> tf.data.Dataset:
    """Returns tf.data.Dataset for the dual encoder task.

    Args:
      params: The data config. When `params.input_path` is the literal
        'dummy', an endlessly repeating all-zeros dataset is returned instead
        of loading real data.
      input_context: Passed to the data loader's `load` method.

    Returns:
      A `tf.data.Dataset` of left/right encoder inputs.
    """
    if params.input_path != 'dummy':
      return data_loader_factory.get_data_loader(params).load(input_context)

    def dummy_data(_):
      # Every element is already a batch of 10 zero-filled examples.
      dummy_ids = tf.zeros((10, params.seq_length), dtype=tf.int32)
      x = dict(
          left_word_ids=dummy_ids,
          left_mask=dummy_ids,
          left_type_ids=dummy_ids,
          right_word_ids=dummy_ids,
          right_mask=dummy_ids,
          right_type_ids=dummy_ids)
      return x

    dataset = tf.data.Dataset.range(1)
    dataset = dataset.repeat()
    dataset = dataset.map(
        dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return dataset

  def build_metrics(self, training=None):
    """Builds the batch-size metric and the recall@k metrics.

    Args:
      training: Unused.

    Returns:
      A list with a `Mean` metric named 'batch_size_per_core', followed by one
      'left_recall_at_{k}' metric per configured `k` and, when the model
      config is `bidirectional`, a matching 'right_recall_at_{k}' metric.
    """
    del training
    model_config = self.task_config.model
    metrics = [tf.keras.metrics.Mean(name='batch_size_per_core')]
    for k in model_config.eval_top_k:
      metrics.append(tf.keras.metrics.SparseTopKCategoricalAccuracy(
          k=k, name=f'left_recall_at_{k}'))
      if model_config.bidirectional:
        metrics.append(tf.keras.metrics.SparseTopKCategoricalAccuracy(
            k=k, name=f'right_recall_at_{k}'))
    return metrics

  def process_metrics(self, metrics, labels, model_outputs):
    """Updates the metrics created by `build_metrics` from the model logits.

    Args:
      metrics: The metric objects to update; they are looked up by name.
      labels: Unused.
      model_outputs: A `(left_logits, right_logits)` pair.
    """
    del labels
    metrics_by_name = {metric.name: metric for metric in metrics}
    left_logits, right_logits = model_outputs
    batch_size = tf_utils.get_shape_list(
        left_logits, name='sequence_output_tensor')[0]
    # Same targets as in `build_losses`: row `i` should rank index `i` first.
    ranking_labels = tf.range(batch_size)

    model_config = self.task_config.model
    for k in model_config.eval_top_k:
      metrics_by_name[f'left_recall_at_{k}'].update_state(
          ranking_labels, left_logits)
      if model_config.bidirectional:
        metrics_by_name[f'right_recall_at_{k}'].update_state(
            ranking_labels, right_logits)

    metrics_by_name['batch_size_per_core'].update_state(batch_size)

  def validation_step(self,
                      inputs,
                      model: tf.keras.Model,
                      metrics=None) -> Mapping[str, tf.Tensor]:
    """Runs one forward pass and collects the loss and metric results.

    Args:
      inputs: The features fed to `model`.
      model: The model to evaluate.
      metrics: Optional metric objects. When empty or None, the model's
        compiled metrics are used instead, if it has any.

    Returns:
      A dict holding the loss under the key `self.loss`, plus the current
      result of each updated metric keyed by the metric name.
    """
    outputs = model(inputs)
    loss = self.build_losses(
        labels=None, model_outputs=outputs, aux_losses=model.losses)
    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, None, outputs)
      logs.update({m.name: m.result() for m in metrics})
    elif model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, None, outputs)
      logs.update({m.name: m.result() for m in model.metrics})
    return logs

  def initialize(self, model):
    """Load a pretrained checkpoint (if exists) and then train from iter 0.

    Only the object registered as 'encoder' in `model.checkpoint_items` is
    restored. Nothing happens when no checkpoint is configured or found.

    Args:
      model: The model whose encoder should be restored.
    """
    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
    if not ckpt_dir_or_file:
      return

    pretrain2finetune_mapping = {
        'encoder': model.checkpoint_items['encoder'],
    }
    ckpt = tf.train.Checkpoint(**pretrain2finetune_mapping)
    status = ckpt.read(ckpt_dir_or_file)
    # The checkpoint may hold more than the encoder, so a partial match is
    # expected; only the objects listed above are asserted to be matched.
    status.expect_partial().assert_existing_objects_matched()
    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.tasks.dual_encoder."""
import functools
import os
from absl.testing import parameterized
import tensorflow as tf
from official.nlp.bert import configs
from official.nlp.bert import export_tfhub
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.data import dual_encoder_dataloader
from official.nlp.tasks import dual_encoder
from official.nlp.tasks import masked_lm
class DualEncoderTaskTest(tf.test.TestCase, parameterized.TestCase):
  """Smoke tests that run `dual_encoder.DualEncoderTask` on dummy data."""

  def setUp(self):
    super().setUp()
    # `input_path="dummy"` makes the task generate all-zeros inputs instead
    # of reading files.
    self._train_data_config = (
        dual_encoder_dataloader.DualEncoderDataConfig(
            input_path="dummy", seq_length=32))

  def get_model_config(self):
    """Returns a small single-layer BERT dual encoder model config."""
    return dual_encoder.ModelConfig(
        max_sequence_length=32,
        encoder=encoders.EncoderConfig(
            bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)))

  def _run_task(self, config):
    """Builds the task from `config` and runs one train and one eval step."""
    task = dual_encoder.DualEncoderTask(config)
    model = task.build_model()
    metrics = task.build_metrics()

    strategy = tf.distribute.get_strategy()
    dataset = strategy.experimental_distribute_datasets_from_function(
        functools.partial(task.build_inputs, config.train_data))
    # The dummy inputs already carry a leading batch dimension of 10, so no
    # extra batching is applied. The previous `dataset.batch(10)` call threw
    # its result away and therefore had no effect.

    iterator = iter(dataset)
    # `learning_rate` is the supported argument name; `lr` is a deprecated
    # alias.
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
    task.train_step(next(iterator), model, optimizer, metrics=metrics)
    task.validation_step(next(iterator), model, metrics=metrics)

  def test_task(self):
    config = dual_encoder.DualEncoderConfig(
        init_checkpoint=self.get_temp_dir(),
        model=self.get_model_config(),
        train_data=self._train_data_config)
    task = dual_encoder.DualEncoderTask(config)
    model = task.build_model()
    metrics = task.build_metrics()
    dataset = task.build_inputs(config.train_data)

    iterator = iter(dataset)
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
    task.train_step(next(iterator), model, optimizer, metrics=metrics)
    task.validation_step(next(iterator), model, metrics=metrics)

    # Saves a checkpoint.
    pretrain_cfg = bert.PretrainerConfig(
        encoder=encoders.EncoderConfig(
            bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)))
    pretrain_model = masked_lm.MaskedLMTask(None).build_model(pretrain_cfg)
    ckpt = tf.train.Checkpoint(
        model=pretrain_model, **pretrain_model.checkpoint_items)
    # NOTE(review): the temp directory itself is used as the save prefix, so
    # the checkpoint files presumably land next to that directory rather than
    # inside it. If so, `initialize` finds no checkpoint and returns early
    # without restoring anything -- TODO confirm.
    ckpt.save(config.init_checkpoint)
    task.initialize(model)

  def _export_bert_tfhub(self):
    """Exports a tiny BERT encoder as a TF-Hub module and returns its path."""
    bert_config = configs.BertConfig(
        vocab_size=30522,
        hidden_size=16,
        intermediate_size=32,
        max_position_embeddings=128,
        num_attention_heads=2,
        num_hidden_layers=1)
    _, encoder = export_tfhub.create_bert_model(bert_config)
    model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
    checkpoint = tf.train.Checkpoint(model=encoder)
    checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
    model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)

    vocab_file = os.path.join(self.get_temp_dir(), "uncased_vocab.txt")
    with tf.io.gfile.GFile(vocab_file, "w") as f:
      f.write("dummy content")

    hub_destination = os.path.join(self.get_temp_dir(), "hub")
    export_tfhub.export_bert_tfhub(bert_config, model_checkpoint_path,
                                   hub_destination, vocab_file)
    return hub_destination

  def test_task_with_hub(self):
    hub_module_url = self._export_bert_tfhub()
    config = dual_encoder.DualEncoderConfig(
        hub_module_url=hub_module_url,
        model=self.get_model_config(),
        train_data=self._train_data_config)
    self._run_task(config)
# Run the tests through the TensorFlow test runner when this file is executed
# directly.
if __name__ == "__main__":
  tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment