Commit e1cb663e authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower

Initial release of NHNet: https://arxiv.org/abs/2001.09386

PiperOrigin-RevId: 306256383
parent 057895af
...@@ -19,6 +19,7 @@ with the same or improved speed and performance with each new TensorFlow build.
| ----- | ----------- | --------- |
| [ALBERT](nlp/albert) | A Lite BERT for Self-supervised Learning of Language Representations | [arXiv:1909.11942](https://arxiv.org/abs/1909.11942) |
| [BERT](nlp/bert) | A powerful pre-trained language representation model: BERT (Bidirectional Encoder Representations from Transformers) | [arXiv:1810.04805](https://arxiv.org/abs/1810.04805) |
| [NHNet](nlp/nhnet) | A transformer-based multi-sequence to sequence model: Generating Representative Headlines for News Stories | [arXiv:2001.09386](https://arxiv.org/abs/2001.09386) |
| [Transformer](nlp/transformer) | A transformer model to translate the WMT English to German dataset | [arXiv:1706.03762](https://arxiv.org/abs/1706.03762) |
| [XLNet](nlp/xlnet) | XLNet: Generalized Autoregressive Pretraining for Language Understanding | [arXiv:1906.08237](https://arxiv.org/abs/1906.08237) |
...
...@@ -7,8 +7,9 @@ state-of-the-art models.
The repository contains the following models, with implementations, pre-trained
model weights, usage scripts and conversion utilities:
* [Albert](albert)
* [Bert](bert)
* [NHNet](nhnet)
* [XLNet](xlnet)
* [Transformer for translation](transformer)
...@@ -16,6 +17,3 @@ Addtional features:
* Distributed trainable on both multi-GPU and TPU
* e2e training for custom models, including both pretraining and finetuning.
# Multi-doc News Headline Generation Model: NHNet
This repository contains the TensorFlow 2.x implementation of NHNet [[1]](#1)
as well as instructions for producing the dataset we described in the paper.
## Introduction
NHNet is a multi-doc news headline generation model. It extends a standard
Transformer-based encoder-decoder model to the multi-doc setting and relies on
an article-level attention layer to capture information common to most (if not
all) input news articles in a news cluster or story, providing robustness
against potential outliers in the input due to imperfect clustering quality.
Our academic paper [[1]](#1), which describes NHNet in detail, can be found
here: https://arxiv.org/abs/2001.09386.
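To make the article-level attention idea concrete, here is a minimal sketch
(simplified, in TensorFlow; the actual `multi_channel_attention` implementation
in this repository also handles multiple attention heads and decoding caches)
of how per-article cross-attention outputs can be combined with article-level
attention weights:
```python
import tensorflow as tf


def combine_article_contexts(doc_attention_probs, per_article_contexts):
  """Article-level weighted sum of per-article cross-attention outputs.

  Args:
    doc_attention_probs: [batch, target_len, num_docs], a softmax over the
      articles in a cluster; outlier articles receive low weights.
    per_article_contexts: [batch, num_docs, target_len, hidden], the
      cross-attention context computed against each article separately.

  Returns:
    A [batch, target_len, hidden] context shared across the cluster.
  """
  return tf.einsum("btn,bnth->bth", doc_attention_probs, per_article_contexts)
```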
## Dataset
**Raw Data:** One can [download](https://github.com/google-research-datasets/NewSHead)
our multi-doc headline dataset, which
contains 369,940 news stories and 932,571 unique URLs. We split these stories
into train (359,940 stories), validation (5,000 stories) and test (5,000
stories) sets by timestamp.
For more information, please check out:
https://github.com/google-research-datasets/NewSHead
### Crawling
Unfortunately, we are not able to release the exact pre-processed dataset used
in the paper. Users need to crawl the URLs themselves; the recommended
pre-processing is to use an open-source library to download and parse the news
content, including the title and leading paragraphs. To ease this process, we
provide a config for [news-please](https://github.com/fhamborg/news-please)
that will crawl and extract news articles on a local machine.
First, install the `news-please` CLI (requires Python 3.x):
```shell
$ pip3 install news-please
```
Next, run the crawler with our provided config and URL list:
```shell
# Set to the path of the downloaded data folder
$ DATA_FOLDER=/path/to/downloaded_dataset
# Use the CLI interface to crawl
$ news-please -c $DATA_FOLDER/news_please
```
By default, it will store crawled
articles under `/tmp/nhnet/`. To terminate the process, press `CTRL+C`.
Crawling may take a few days (48 hours in our test), depending on the network
environment and the number of threads set in the config. As the crawling tool
won't stop automatically, it is not straightforward to check the progress. We
suggest terminating the job once no new articles have been crawled for a short
time period (e.g., 10 minutes), which you can check by running:
```shell
$ find /tmp/nhnet -type f | wc -l
```
Please note that, as time goes by, some URLs are expected to no longer be
available on the web.
### Data Processing
Given the crawled articles under `/tmp/nhnet/`, we would like to transform
these textual articles into a set of `TFRecord` files containing serialized
tensorflow.Example protocol buffers, with feature keys following the BERT
[[2]](#2) convention but extended for multiple text segments. We will later
use these processed TFRecords for training and evaluation.
To do this, please first download a [BERT pretrained checkpoint](https://github.com/tensorflow/models/tree/master/official/nlp/bert#access-to-pretrained-checkpoints)
(`BERT-Base,Uncased` preferred for efficiency) and decompress the `tar.gz` file.
We need the vocabulary file now and will later use the checkpoint for NHNet
initialization.
Next, we can run the following data preprocessing script, which may take a few
hours to read the files and tokenize the article content:
```shell
# Recall that we use DATA_FOLDER=/path/to/downloaded_dataset
$ python3 raw_data_preprocess.py \
-crawled_articles=/tmp/nhnet \
-vocab=/path/to/bert_checkpoint/vocab.txt \
-do_lower_case=True \
-len_title=15 \
-len_passage=200 \
-max_num_articles=5 \
-data_folder=$DATA_FOLDER
```
This Python script will export the processed train/valid/eval files under
`$DATA_FOLDER/processed/`.
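To sanity-check the output, one can parse a single record. Below is a minimal
sketch, assuming the default multi-doc feature keys consumed by
`input_pipeline.py` (`input_ids_a` for the cluster title, and `input_ids_b`
through `input_ids_f`, plus the matching `input_mask_*` and `segment_ids_*`,
for up to five passages):
```python
import tensorflow as tf

LEN_TITLE, LEN_PASSAGE = 15, 200  # Must match the preprocessing flags above.
name_to_features = {
    "input_ids_a": tf.io.FixedLenFeature([LEN_TITLE], tf.int64),
}
for passage in ["b", "c", "d", "e", "f"]:
  for feature in ["input_ids", "input_mask", "segment_ids"]:
    name_to_features["%s_%s" % (feature, passage)] = tf.io.FixedLenFeature(
        [LEN_PASSAGE], tf.int64)

dataset = tf.data.TFRecordDataset(
    tf.io.gfile.glob("/path/to/downloaded_dataset/processed/train.tfrecord*"))
for record in dataset.take(1):
  example = tf.io.parse_single_example(record, name_to_features)
  print({name: tensor.shape for name, tensor in example.items()})
```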
## Training
Please first install TensorFlow 2 and the TensorFlow Model Garden following the
[requirements section](https://github.com/tensorflow/models/tree/master/official#requirements).
### CPU/GPU
```shell
$ python3 trainer.py \
--mode=train_and_eval \
--vocab=/path/to/bert_checkpoint/vocab.txt \
--init_checkpoint=/path/to/bert_checkpoint/bert_model.ckpt \
--params_override='init_from_bert2bert=false' \
--train_file_pattern=$DATA_FOLDER/processed/train.tfrecord* \
--model_dir=/path/to/output/model \
--len_title=15 \
--len_passage=200 \
--max_num_articles=5 \
--model_type=nhnet \
--train_batch_size=16 \
--train_steps=10000 \
--steps_per_loop=1 \
--checkpoint_interval=100
```
### TPU
```shell
$ python3 trainer.py \
--mode=train_and_eval \
--vocab=/path/to/bert_checkpoint/vocab.txt \
--init_checkpoint=/path/to/bert_checkpoint/bert_model.ckpt \
--params_override='init_from_bert2bert=false' \
--train_file_pattern=$DATA_FOLDER/processed/train.tfrecord* \
--model_dir=/path/to/output/model \
--len_title=15 \
--len_passage=200 \
--max_num_articles=5 \
--model_type=nhnet \
--train_batch_size=1024 \
--train_steps=10000 \
--steps_per_loop=1000 \
--checkpoint_interval=1000 \
--distribution_strategy=tpu \
--tpu=grpc://${TPU_IP_ADDRESS}:8470
```
In the paper, we train for more than 10k steps with a batch size of 1024 on a
TPU v3-64.
Note that `trainer.py` also supports a `train`-only mode and a continuous
`eval` mode. For large-scale TPU training, we recommend having one process run
the `train` mode and another process run the continuous `eval` mode, which can
run on GPUs; see the sketch after this paragraph.
This is the setting we commonly use for large-scale experiments, because `eval`
is then non-blocking to the expensive training load.
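For example (a sketch only; the training flags mirror the TPU command above,
and the `--eval_file_pattern` flag for the `eval` process is our assumption,
so please verify the exact flag names against `trainer.py`):
```shell
# Process 1 (TPU): training only.
$ python3 trainer.py --mode=train --distribution_strategy=tpu \
    --model_dir=/path/to/output/model ...

# Process 2 (GPU): continuously evaluates new checkpoints in model_dir.
$ python3 trainer.py --mode=eval \
    --eval_file_pattern=$DATA_FOLDER/processed/valid.tfrecord* \
    --model_dir=/path/to/output/model ...
```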
### Metrics
**Note: the metrics reported by `evaluation.py` are approximated at the
word-piece level rather than on the real string tokens, so some metrics, such
as BLEU scores, can be off.**
We will release a Colab to evaluate results at the string level soon.
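In the meantime, a rough string-level evaluation can be done by detokenizing
the word pieces and scoring with an external library. A sketch, assuming the
`rouge-score` pip package (which is not part of this repository):
```python
from rouge_score import rouge_scorer


def detokenize(word_pieces):
  """Joins BERT word pieces back into a plain string."""
  return " ".join(word_pieces).replace(" ##", "")


scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
reference = detokenize(["head", "##line", "for", "a", "story"])
prediction = detokenize(["head", "##line", "of", "a", "story"])
print(scorer.score(reference, prediction)["rougeL"].fmeasure)
```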
## References
<a id="1">[1]</a> Xiaotao Gu, Yuning Mao, Jiawei Han, Jialu Liu, You Wu, Cong
Yu, Daniel Finnie, Hongkun Yu, Jiaqi Zhai and Nicholas Zukoski "Generating
Representative Headlines for News Stories": https://arxiv.org/abs/2001.09386.
World Wide Web Conf. (WWW’2020).
<a id="2">[2]</a> Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina
Toutanova "BERT: Pre-training of Deep Bidirectional Transformers for Language
Understanding": https://arxiv.org/abs/1810.04805.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common NHNet/Bert2Bert configuration."""
from typing import List, Text
import dataclasses
from official.modeling.hyperparams import base_config
@dataclasses.dataclass
class BERT2BERTConfig(base_config.Config):
"""High-level configurations for BERT2BERT model.
These include parameters that are not directly related to the experiment,
e.g. encoder, decoder, prediction, training, etc.
"""
vocab_size: int = 30522
hidden_size: int = 768
num_hidden_layers: int = 12
num_attention_heads: int = 12
intermediate_size: int = 3072
hidden_act: str = "gelu"
hidden_dropout_prob: float = 0.1
attention_probs_dropout_prob: float = 0.1
max_position_embeddings: int = 512
type_vocab_size: int = 2
initializer_range: float = 0.02
decoder_intermediate_size: int = 3072
num_decoder_attn_heads: int = 12
num_decoder_layers: int = 12
label_smoothing: float = 0.1
learning_rate: float = 0.05
learning_rate_warmup_steps: int = 20000
optimizer: str = "Adam"
adam_beta1: float = 0.9
adam_beta2: float = 0.997
adam_epsilon: float = 1e-09
# predict params
beam_size: int = 5
alpha: float = 0.6
initializer_gain: float = 1.0
use_cache: bool = True
# input params
input_sharding: bool = False
input_data_not_padded: bool = False
pad_token_id: int = 0
end_token_id: int = 102
start_token_id: int = 101
@dataclasses.dataclass
class NHNetConfig(BERT2BERTConfig):
"""High-level configurations for NHNet model.
These include parameters that are not directly related to the experiment,
e.g. encoder, decoder, prediction, training, etc.
"""
multi_channel_cross_attention: bool = True
passage_list: List[Text] = dataclasses.field(
default_factory=lambda: [chr(ord("b") + i) for i in range(5)])
# Initialization method.
# If init_from_bert2bert is false, we assume the checkpoint is from BERT
# pretraining and only encoder and self-attention variables are initialized.
init_from_bert2bert: bool = True
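# Hedged usage sketch: both configs are dataclasses built on
# `base_config.Config`, so fields can be overridden at construction and
# validated afterwards, e.g.
#
#   cfg = NHNetConfig(init_from_bert2bert=False)
#   cfg.validate()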
UNITTEST_CONFIG = {
"attention_probs_dropout_prob": 0.0,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.0,
"hidden_size": 16,
"initializer_range": 0.02,
"intermediate_size": 32,
"max_position_embeddings": 128,
"num_attention_heads": 2,
"num_hidden_layers": 1,
"type_vocab_size": 2,
"vocab_size": 30522,
"initializer_gain": 1.0,
"decoder_intermediate_size": 32,
"num_decoder_attn_heads": 2,
"num_decoder_layers": 1,
"use_cache": True,
"input_data_not_padded": False,
"pad_token_id": 0,
"end_token_id": 102,
"start_token_id": 101,
}
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for configs."""
import tensorflow as tf
from official.nlp.nhnet import configs
BERT2BERT_CONFIG = {
"vocab_size": 30522,
"hidden_size": 768,
"num_hidden_layers": 12,
"num_attention_heads": 12,
"intermediate_size": 3072,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"attention_probs_dropout_prob": 0.1,
"max_position_embeddings": 512,
"type_vocab_size": 2,
"initializer_range": 0.02,
# model params
"decoder_intermediate_size": 3072,
"num_decoder_attn_heads": 12,
"num_decoder_layers": 12,
# training params
"label_smoothing": 0.1,
"learning_rate": 0.05,
"learning_rate_warmup_steps": 20000,
"optimizer": "Adam",
"adam_beta1": 0.9,
"adam_beta2": 0.997,
"adam_epsilon": 1e-09,
# predict params
"beam_size": 5,
"alpha": 0.6,
"initializer_gain": 1.0,
"use_cache": True,
# input params
"input_sharding": False,
"input_data_not_padded": False,
"pad_token_id": 0,
"end_token_id": 102,
"start_token_id": 101,
}
NHNET_CONFIG = {
"vocab_size": 30522,
"hidden_size": 768,
"num_hidden_layers": 12,
"num_attention_heads": 12,
"intermediate_size": 3072,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"attention_probs_dropout_prob": 0.1,
"max_position_embeddings": 512,
"type_vocab_size": 2,
"initializer_range": 0.02,
# model params
"decoder_intermediate_size": 3072,
"num_decoder_attn_heads": 12,
"num_decoder_layers": 12,
"multi_channel_cross_attention": True,
# training params
"label_smoothing": 0.1,
"learning_rate": 0.05,
"learning_rate_warmup_steps": 20000,
"optimizer": "Adam",
"adam_beta1": 0.9,
"adam_beta2": 0.997,
"adam_epsilon": 1e-09,
# predict params
"beam_size": 5,
"alpha": 0.6,
"initializer_gain": 1.0,
"use_cache": True,
# input params
"passage_list": ["b", "c", "d", "e", "f"],
"input_sharding": False,
"input_data_not_padded": False,
"pad_token_id": 0,
"end_token_id": 102,
"start_token_id": 101,
"init_from_bert2bert": True,
}
class ConfigsTest(tf.test.TestCase):
def test_configs(self):
cfg = configs.BERT2BERTConfig()
cfg.validate()
self.assertEqual(cfg.as_dict(), BERT2BERT_CONFIG)
def test_nhnet_config(self):
cfg = configs.NHNetConfig()
cfg.validate()
self.assertEqual(cfg.as_dict(), NHNET_CONFIG)
if __name__ == "__main__":
tf.test.main()
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Transformer decoder that mimics a BERT encoder, to load BERT checkpoints."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling import layers
from official.nlp.nhnet import multi_channel_attention
from official.nlp.transformer import model_utils as transformer_utils
class TransformerDecoderBlock(tf.keras.layers.Layer):
"""Single transformer layer for decoder.
It has three sub-layers:
(1) a multi-head self-attention mechanism.
(2) an encoder-decoder attention layer.
(3) a positionwise fully connected feed-forward network.
"""
def __init__(self,
hidden_size=768,
num_attention_heads=12,
intermediate_size=3072,
intermediate_activation="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
initializer_range=0.02,
multi_channel_cross_attention=False,
**kwargs):
super(TransformerDecoderBlock, self).__init__(**kwargs)
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.intermediate_activation = tf_utils.get_activation(
intermediate_activation)
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.multi_channel_cross_attention = multi_channel_cross_attention
self._kernel_initializer = tf.keras.initializers.TruncatedNormal(
stddev=initializer_range)
self._bias_initializer = tf.keras.initializers.get("zeros")
if self.multi_channel_cross_attention:
self._cross_attention_cls = multi_channel_attention.MultiChannelAttention
else:
self._cross_attention_cls = layers.MultiHeadAttention
if self.hidden_size % self.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (self.hidden_size, self.num_attention_heads))
self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
def build(self, unused_input_shapes):
# Self attention.
self.self_attention = layers.CachedAttention(
num_heads=self.num_attention_heads,
head_size=self.attention_head_size,
dropout_rate=self.attention_probs_dropout_prob,
kernel_initializer=self._kernel_initializer,
name="self_attention")
self.self_attention_output_dense = layers.DenseEinsum(
output_shape=self.hidden_size,
num_summed_dimensions=2,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
name="self_attention_output")
self.self_attention_dropout = tf.keras.layers.Dropout(
rate=self.hidden_dropout_prob)
self.self_attention_layer_norm = (
tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm", axis=-1, epsilon=1e-12))
# Encoder-decoder attention.
self.encdec_attention = self._cross_attention_cls(
num_heads=self.num_attention_heads,
head_size=self.attention_head_size,
dropout_rate=self.attention_probs_dropout_prob,
kernel_initializer=self._kernel_initializer,
name="attention/encdec")
self.encdec_attention_output_dense = layers.DenseEinsum(
output_shape=self.hidden_size,
num_summed_dimensions=2,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
name="attention/encdec_output")
self.encdec_attention_dropout = tf.keras.layers.Dropout(
rate=self.hidden_dropout_prob)
self.encdec_attention_layer_norm = (
tf.keras.layers.LayerNormalization(
name="attention/encdec_output_layer_norm", axis=-1, epsilon=1e-12))
# Feed-forward projection.
self.intermediate_dense = layers.DenseEinsum(
output_shape=self.intermediate_size,
activation=None,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
name="intermediate")
self.intermediate_activation_layer = tf.keras.layers.Activation(
self.intermediate_activation)
self.output_dense = layers.DenseEinsum(
output_shape=self.hidden_size,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
name="output")
self.output_dropout = tf.keras.layers.Dropout(rate=self.hidden_dropout_prob)
self.output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm", axis=-1, epsilon=1e-12)
super(TransformerDecoderBlock, self).build(unused_input_shapes)
def common_layers_with_encoder(self):
"""Gets layer objects that can make a Transformer encoder block."""
return [
self.self_attention, self.self_attention_output_dense,
self.self_attention_layer_norm, self.intermediate_dense,
self.output_dense, self.output_layer_norm
]
def call(self, inputs, cache=None, decode_loop_step=None):
if self.multi_channel_cross_attention:
if len(inputs) != 5:
raise ValueError(
    "TransformerDecoderBlock must have 5 inputs when it uses "
    "multi_channel_cross_attention, but it got: %d" % len(inputs))
elif len(inputs) != 4:
raise ValueError(
"TransformerDecoderBlock must have 4 inputs, but it got: %d" %
len(inputs))
input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
if cache is None:
self_attention_inputs = [input_tensor, input_tensor, self_attention_mask]
else:
self_attention_inputs = [
input_tensor, input_tensor, self_attention_mask, cache
]
self_attention_output, cache = self.self_attention(
self_attention_inputs, decode_loop_step=decode_loop_step)
self_attention_output = self.self_attention_output_dense(
self_attention_output)
self_attention_output = self.self_attention_dropout(self_attention_output)
self_attention_output = self.self_attention_layer_norm(
input_tensor + self_attention_output)
cross_attn_inputs = [self_attention_output, memory, attention_mask]
if self.multi_channel_cross_attention:
# Accesses the 5th input tensor for the doc-attention probabilities.
cross_attn_inputs.append(inputs[-1])
attention_output = self.encdec_attention(cross_attn_inputs)
attention_output = self.encdec_attention_output_dense(attention_output)
attention_output = self.encdec_attention_dropout(attention_output)
attention_output = self.encdec_attention_layer_norm(self_attention_output +
attention_output)
intermediate_output = self.intermediate_dense(attention_output)
intermediate_output = self.intermediate_activation_layer(
intermediate_output)
layer_output = self.output_dense(intermediate_output)
layer_output = self.output_dropout(layer_output)
layer_output = self.output_layer_norm(layer_output + attention_output)
return layer_output, cache
class TransformerDecoder(tf.keras.layers.Layer):
"""Transformer decoder stack."""
def __init__(self,
num_hidden_layers=12,
hidden_size=768,
num_attention_heads=12,
intermediate_size=3072,
intermediate_activation="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
initializer_range=0.02,
attend_to_last_layer=True,
multi_channel_cross_attention=False,
**kwargs):
super(TransformerDecoder, self).__init__(**kwargs)
self.num_hidden_layers = num_hidden_layers
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.intermediate_activation = tf_utils.get_activation(
intermediate_activation)
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.initializer_range = initializer_range
self.attend_to_last_layer = attend_to_last_layer
self.multi_channel_cross_attention = multi_channel_cross_attention
def build(self, unused_input_shapes):
"""Implements build() for the layer."""
self.layers = []
for i in range(self.num_hidden_layers):
self.layers.append(
TransformerDecoderBlock(
hidden_size=self.hidden_size,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
intermediate_activation=self.intermediate_activation,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
initializer_range=self.initializer_range,
multi_channel_cross_attention=self.multi_channel_cross_attention,
name=("layer_%d" % i)))
super(TransformerDecoder, self).build(unused_input_shapes)
def call(self, inputs, cache=None, decode_loop_step=None):
"""Return the output of the decoder layer stacks.
Args:
inputs: A dictionary of inputs. `decoder_inputs` is a tf.int32 tensor for
input ids. `encoder_outputs` is a list of tensors with shape
[batch_size, input_length, hidden_size]. `self_attention_mask` is the
bias for decoder self-attention layer. [1, 1, target_length,
target_length]. `attention_mask` is the bias for encoder-decoder
attention layer, [batch_size, 1, 1, input_length].
cache: A dictionary of cache tensors, including key & value attentions.
decode_loop_step: an integer to indicate the step inside a decoding loop.
Returns:
Output of decoder layer stack.
float32 tensor with shape [batch_size, target_length, hidden_size]
"""
decoder_inputs = inputs["decoder_inputs"]
encoder_outputs = inputs["encoder_outputs"]
self_attention_mask = inputs["self_attention_mask"]
attention_mask = inputs["attention_mask"]
decoder_shape = tf_utils.get_shape_list(decoder_inputs, expected_rank=3)
batch_size = decoder_shape[0]
decoder_length = decoder_shape[1]
def _to_bert_self_attention_mask(matrix):
"""[1, 1, target_len, target_len] -> [bs, target_len, target_len]."""
matrix = tf.squeeze(matrix, axis=[1])
matrix = tf.tile(matrix, [batch_size, 1, 1])
return matrix
def _to_bert_encdec_attention_mask(matrix):
"""[bs, 1, 1, input_len] -> [bs, target_len, input_len]."""
if self.multi_channel_cross_attention:
matrix = tf.expand_dims(matrix, axis=2)
matrix = tf.tile(matrix, [1, 1, decoder_length, 1])
else:
matrix = tf.squeeze(matrix, axis=[1])
matrix = tf.tile(matrix, [1, decoder_length, 1])
return matrix
attention_mask = _to_bert_encdec_attention_mask(attention_mask)
self_attention_mask = _to_bert_self_attention_mask(self_attention_mask)
output_tensor = decoder_inputs
for layer_idx in range(self.num_hidden_layers):
if self.attend_to_last_layer:
memory = encoder_outputs[-1]
else:
memory = encoder_outputs[layer_idx]
if self.multi_channel_cross_attention:
transformer_inputs = [
output_tensor, memory, attention_mask, self_attention_mask,
inputs["doc_attention_probs"]
]
else:
transformer_inputs = [
output_tensor, memory, attention_mask, self_attention_mask
]
# Gets the cache for decoding.
if cache is None:
output_tensor, _ = self.layers[layer_idx](transformer_inputs)
else:
cache_layer_idx = str(layer_idx)
output_tensor, cache[cache_layer_idx] = self.layers[layer_idx](
transformer_inputs,
cache=cache[cache_layer_idx],
decode_loop_step=decode_loop_step)
return output_tensor, cache
def get_attention_bias(input_tensor,
bias_type,
padding_value=0,
max_length=None):
"""A helper function to get various attention bias tensors."""
if bias_type not in ("single_cross", "multi_cross", "decoder_self"):
raise ValueError("Invalid attention bias type: %s" % bias_type)
if bias_type == "single_cross":
length = tf_utils.get_shape_list(input_tensor, expected_rank=2)[1]
bias = transformer_utils.get_padding_bias(
input_tensor, padding_value=padding_value)
elif bias_type == "multi_cross":
length = tf_utils.get_shape_list(input_tensor, expected_rank=3)[2]
padding = transformer_utils.get_padding(
input_tensor, padding_value=padding_value)
bias = padding * -1e9
else:
if max_length is not None:
length = max_length
else:
length = tf_utils.get_shape_list(input_tensor, expected_rank=2)[1]
bias = transformer_utils.get_decoder_self_attention_bias(length)
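  # Convert the additive bias (0 = keep, -1e9 = masked) into a {0, 1} mask,
  # where 1 marks positions that may be attended to.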
return tf.where(bias < 0, tf.zeros_like(bias), tf.ones_like(bias))
class AttentionBias(tf.keras.layers.Layer):
def __init__(self, bias_type, **kwargs):
super(AttentionBias, self).__init__(**kwargs)
self.bias_type = bias_type
def call(self, inputs):
return get_attention_bias(inputs, self.bias_type)
class EmbeddingPostprocessor(tf.keras.layers.Layer):
"""Performs various post-processing on a word embedding tensor."""
def __init__(self,
use_type_embeddings=False,
token_type_vocab_size=None,
use_position_embeddings=True,
max_position_embeddings=512,
dropout_prob=0.0,
initializer_range=0.02,
initializer=None,
**kwargs):
super(EmbeddingPostprocessor, self).__init__(**kwargs)
self.use_type_embeddings = use_type_embeddings
self.token_type_vocab_size = token_type_vocab_size
self.use_position_embeddings = use_position_embeddings
self.max_position_embeddings = max_position_embeddings
self.dropout_prob = dropout_prob
self.initializer_range = initializer_range
if not initializer:
self.initializer = tf.keras.initializers.TruncatedNormal(
stddev=initializer_range)
else:
self.initializer = initializer
if self.use_type_embeddings and not self.token_type_vocab_size:
raise ValueError("If `use_type_embeddings` is True, then "
"`token_type_vocab_size` must be specified.")
def build(self, input_shapes):
"""Implements build() for the layer."""
(word_embeddings_shape, _) = input_shapes
width = word_embeddings_shape.as_list()[-1]
self.type_embeddings = None
if self.use_type_embeddings:
self.type_embeddings = self.add_weight(
"type_embeddings",
shape=[self.token_type_vocab_size, width],
initializer=tf.keras.initializers.TruncatedNormal(
stddev=self.initializer_range),
dtype=self.dtype)
self.position_embeddings = None
if self.use_position_embeddings:
self.position_embeddings = self.add_weight(
"position_embeddings",
shape=[self.max_position_embeddings, width],
initializer=tf.keras.initializers.TruncatedNormal(
stddev=self.initializer_range),
dtype=self.dtype)
self.output_layer_norm = tf.keras.layers.LayerNormalization(
name="layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)
self.output_dropout = tf.keras.layers.Dropout(
rate=self.dropout_prob, dtype=tf.float32)
super(EmbeddingPostprocessor, self).build(input_shapes)
def __call__(self, word_embeddings, token_type_ids=None, **kwargs):
inputs = tf_utils.pack_inputs([word_embeddings, token_type_ids])
return super(EmbeddingPostprocessor, self).__call__(inputs, **kwargs)
def call(self, inputs):
"""Implements call() for the layer."""
unpacked_inputs = tf_utils.unpack_inputs(inputs)
word_embeddings = unpacked_inputs[0]
token_type_ids = unpacked_inputs[1]
input_shape = tf_utils.get_shape_list(word_embeddings, expected_rank=3)
batch_size = input_shape[0]
seq_length = input_shape[1]
width = input_shape[2]
output = word_embeddings
if self.use_type_embeddings:
flat_token_type_ids = tf.reshape(token_type_ids, [-1])
token_type_embeddings = tf.gather(self.type_embeddings,
flat_token_type_ids)
token_type_embeddings = tf.reshape(token_type_embeddings,
[batch_size, seq_length, width])
output += token_type_embeddings
if self.use_position_embeddings:
position_embeddings = tf.expand_dims(
tf.slice(self.position_embeddings, [0, 0], [seq_length, width]),
axis=0)
output += position_embeddings
output = self.output_layer_norm(output)
output = self.output_dropout(output)
return output
class Decoder(tf.keras.layers.Layer):
"""The decoder network which can reuse encoder embeddings for target."""
def __init__(self, config, embedding_lookup=None, **kwargs):
super(Decoder, self).__init__(**kwargs)
self.config = config
# Shares vocabulary embedding.
self.embedding_lookup = None
if embedding_lookup:
self.embedding_lookup = embedding_lookup
def build(self, unused_input_shapes):
"""Implements build() for the layer."""
if self.embedding_lookup is None:
self.embedding_lookup = layers.OnDeviceEmbedding(
vocab_size=self.config.vocab_size,
embedding_width=self.config.hidden_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=self.config.initializer_range),
name="target_embeddings")
self.embedding_postprocessor = EmbeddingPostprocessor(
use_type_embeddings=False,
use_position_embeddings=True,
max_position_embeddings=self.config.max_position_embeddings,
dropout_prob=self.config.hidden_dropout_prob,
initializer=tf.keras.initializers.VarianceScaling(
scale=self.config.initializer_gain,
mode="fan_avg",
distribution="uniform"),
name="embedding_postprocessor")
# Decoder can use a different intermediate size.
self.multi_channel_cross_attention = self.config.get(
"multi_channel_cross_attention", False)
self.decoder = TransformerDecoder(
num_hidden_layers=self.config.num_decoder_layers,
hidden_size=self.config.hidden_size,
num_attention_heads=self.config.num_decoder_attn_heads,
intermediate_size=self.config.decoder_intermediate_size,
intermediate_activation=self.config.hidden_act,
hidden_dropout_prob=self.config.hidden_dropout_prob,
attention_probs_dropout_prob=self.config.attention_probs_dropout_prob,
initializer_range=self.config.initializer_range,
multi_channel_cross_attention=self.multi_channel_cross_attention,
name="decoder")
super(Decoder, self).build(unused_input_shapes)
def _decoding_step_time_signal(self, target_embeds, decode_loop_step):
"""Applies time signal (positional embeddings) for decoded embeddings."""
# TODO(hongkuny): migrate to keras bert and design a module to handle this.
output = target_embeds
if self.embedding_postprocessor.use_position_embeddings:
position_embeddings = tf.gather(
self.embedding_postprocessor.position_embeddings, [decode_loop_step])
# Broadcasts to all sequences inside a batch.
output += position_embeddings
output = self.embedding_postprocessor.output_layer_norm(output)
output = self.embedding_postprocessor.output_dropout(output)
return output
def call(self,
inputs,
cache=None,
decode_loop_step=None,
padded_decode=False):
"""Implements call() for the layer.
Args:
inputs: a list of input tensors.
cache: A dictionary of cache tensors, including key & value attentions.
Due to a Keras limitation, we rely on side effects to update the cache, so
the cached tensor states will be mutated.
decode_loop_step: an integer to indicate the step inside a decoding loop.
padded_decode: a boolean indicates if the pass is for padded decoding.
Returns:
Decoder output tensors.
"""
attention_bias = inputs["attention_bias"]
target_ids = inputs["target_ids"]
all_encoder_outputs = inputs["all_encoder_outputs"]
self_attention_bias = inputs["self_attention_bias"]
if not isinstance(all_encoder_outputs, list):
all_encoder_outputs = [all_encoder_outputs]
target_embeds = self.embedding_lookup(target_ids)
if decode_loop_step is None:
target_embeds = self.embedding_postprocessor(target_embeds)
else:
target_embeds = self._decoding_step_time_signal(target_embeds,
decode_loop_step)
decoder_inputs = dict(
decoder_inputs=target_embeds,
encoder_outputs=all_encoder_outputs,
self_attention_mask=self_attention_bias,
attention_mask=attention_bias)
if self.multi_channel_cross_attention:
decoder_inputs["doc_attention_probs"] = inputs["doc_attention_probs"]
decode_outputs, cache = self.decoder(
decoder_inputs, cache, decode_loop_step if padded_decode else None)
return decode_outputs
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for nlp.nhnet.decoder."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from official.nlp.modeling import layers
from official.nlp.nhnet import configs
from official.nlp.nhnet import decoder
from official.nlp.nhnet import utils
def _create_cache(batch_size, init_decode_length, num_heads, head_size):
return {
"key":
tf.zeros([batch_size, init_decode_length, num_heads, head_size],
dtype=tf.float32),
"value":
tf.zeros([batch_size, init_decode_length, num_heads, head_size],
dtype=tf.float32)
}
class DecoderTest(tf.test.TestCase):
def setUp(self):
super(DecoderTest, self).setUp()
self._config = utils.get_test_params()
def test_transformer_decoder(self):
decoder_block = decoder.TransformerDecoder(
num_hidden_layers=self._config.num_hidden_layers,
hidden_size=self._config.hidden_size,
num_attention_heads=self._config.num_attention_heads,
intermediate_size=self._config.intermediate_size,
intermediate_activation=self._config.hidden_act,
hidden_dropout_prob=self._config.hidden_dropout_prob,
attention_probs_dropout_prob=self._config.attention_probs_dropout_prob,
initializer_range=self._config.initializer_range)
decoder_block.build(None)
self.assertEqual(len(decoder_block.layers), self._config.num_hidden_layers)
def test_decoder_block_with_cache(self):
decoder_block = decoder.TransformerDecoderBlock(
hidden_size=self._config.hidden_size,
num_attention_heads=self._config.num_attention_heads,
intermediate_size=self._config.intermediate_size,
intermediate_activation=self._config.hidden_act,
hidden_dropout_prob=self._config.hidden_dropout_prob,
attention_probs_dropout_prob=self._config.attention_probs_dropout_prob,
initializer_range=self._config.initializer_range)
# Forward path.
dummy_tensor = tf.zeros([2, 4, self._config.hidden_size], dtype=tf.float32)
dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
inputs = [dummy_tensor, dummy_tensor, dummy_mask, dummy_mask]
cache = _create_cache(
2, 0, self._config.num_attention_heads,
self._config.hidden_size // self._config.num_attention_heads)
output, cache = decoder_block(inputs, cache)
self.assertEqual(output.shape, (2, 4, self._config.hidden_size))
self.assertEqual(cache["value"].shape, (2, 4, 2, 8))
def test_bert_decoder(self):
seq_length = 10
encoder_input_ids = tf.keras.layers.Input(
shape=(seq_length,), name="encoder_input_ids", dtype=tf.int32)
target_ids = tf.keras.layers.Input(
shape=(seq_length,), name="target_ids", dtype=tf.int32)
encoder_outputs = tf.keras.layers.Input(
shape=(seq_length, self._config.hidden_size),
name="all_encoder_outputs",
dtype=tf.float32)
embedding_lookup = layers.OnDeviceEmbedding(
vocab_size=self._config.vocab_size,
embedding_width=self._config.hidden_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=self._config.initializer_range),
name="word_embeddings")
cross_attention_bias = decoder.AttentionBias(bias_type="single_cross")(
encoder_input_ids)
self_attention_bias = decoder.AttentionBias(bias_type="decoder_self")(
target_ids)
inputs = dict(
attention_bias=cross_attention_bias,
self_attention_bias=self_attention_bias,
target_ids=target_ids,
all_encoder_outputs=encoder_outputs)
decoder_layer = decoder.Decoder(self._config, embedding_lookup)
outputs = decoder_layer(inputs)
model_inputs = dict(
encoder_input_ids=encoder_input_ids,
target_ids=target_ids,
all_encoder_outputs=encoder_outputs)
model = tf.keras.Model(inputs=model_inputs, outputs=outputs, name="test")
self.assertLen(decoder_layer.trainable_weights, 30)
# Forward path.
fake_inputs = {
"encoder_input_ids": np.zeros((2, 10), dtype=np.int32),
"target_ids": np.zeros((2, 10), dtype=np.int32),
"all_encoder_outputs": np.zeros((2, 10, 16), dtype=np.float32),
}
output_tensor = model(fake_inputs)
self.assertEqual(output_tensor.shape, (2, 10, 16))
def test_multi_doc_decoder(self):
self._config = utils.get_test_params(cls=configs.NHNetConfig)
seq_length = 10
num_docs = 5
encoder_input_ids = tf.keras.layers.Input(
shape=(num_docs, seq_length), name="encoder_input_ids", dtype=tf.int32)
target_ids = tf.keras.layers.Input(
shape=(seq_length,), name="target_ids", dtype=tf.int32)
encoder_outputs = tf.keras.layers.Input(
shape=(num_docs, seq_length, self._config.hidden_size),
name="all_encoder_outputs",
dtype=tf.float32)
embedding_lookup = layers.OnDeviceEmbedding(
vocab_size=self._config.vocab_size,
embedding_width=self._config.hidden_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=self._config.initializer_range),
name="word_embeddings")
doc_attention_probs = tf.keras.layers.Input(
shape=(self._config.num_decoder_attn_heads, seq_length, num_docs),
name="doc_attention_probs",
dtype=tf.float32)
cross_attention_bias = decoder.AttentionBias(bias_type="multi_cross")(
encoder_input_ids)
self_attention_bias = decoder.AttentionBias(bias_type="decoder_self")(
target_ids)
inputs = dict(
attention_bias=cross_attention_bias,
self_attention_bias=self_attention_bias,
target_ids=target_ids,
all_encoder_outputs=encoder_outputs,
doc_attention_probs=doc_attention_probs)
decoder_layer = decoder.Decoder(self._config, embedding_lookup)
outputs = decoder_layer(inputs)
model_inputs = dict(
encoder_input_ids=encoder_input_ids,
target_ids=target_ids,
all_encoder_outputs=encoder_outputs,
doc_attention_probs=doc_attention_probs)
model = tf.keras.Model(inputs=model_inputs, outputs=outputs, name="test")
self.assertLen(decoder_layer.trainable_weights, 30)
# Forward path.
fake_inputs = {
"encoder_input_ids":
np.zeros((2, num_docs, seq_length), dtype=np.int32),
"target_ids":
np.zeros((2, seq_length), dtype=np.int32),
"all_encoder_outputs":
np.zeros((2, num_docs, seq_length, 16), dtype=np.float32),
"doc_attention_probs":
np.zeros(
(2, self._config.num_decoder_attn_heads, seq_length, num_docs),
dtype=np.float32)
}
output_tensor = model(fake_inputs)
self.assertEqual(output_tensor.shape, (2, seq_length, 16))
if __name__ == "__main__":
tf.test.main()
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluation for Bert2Bert."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import os
from absl import logging
import numpy as np
import tensorflow as tf
from official.nlp.nhnet import input_pipeline
from official.nlp.nhnet import models
from official.nlp.transformer import metrics as metrics_v2
from official.nlp.transformer.utils import metrics
def rouge_l_fscore(logits, labels):
"""ROUGE scores computation between labels and predictions.
This is an approximate ROUGE scoring method since we do not glue word pieces
or decode the ids and tokenize the output.
Args:
logits: tensor, model predictions
labels: tensor, gold output.
Returns:
rouge_l_fscore: approx rouge-l f1 score.
"""
predictions = np.argmax(logits, axis=-1)
rouge_l_f_score = metrics.rouge_l_sentence_level(predictions, labels)
return rouge_l_f_score
def rouge_2_fscore(logits, labels):
"""ROUGE-2 F1 score computation between labels and predictions.
This is an approximate ROUGE scoring method since we do not glue word pieces
or decode the ids and tokenize the output.
Args:
logits: tensor, model predictions
labels: tensor, gold output.
Returns:
rouge2_fscore: approx rouge-2 f1 score.
"""
predictions = np.argmax(logits, axis=-1)
rouge_2_f_score = metrics.rouge_n(predictions, labels)
return rouge_2_f_score
def bleu_score(logits, labels):
"""Approximate BLEU score computation between labels and predictions.
An approximate BLEU scoring method since we do not glue word pieces or
decode the ids and tokenize the output. By default, we use ngram order of 4
and use brevity penalty. Also, this does not have beam search.
Args:
logits: Tensor of size [batch_size, length_logits, vocab_size]
labels: Tensor of size [batch-size, length_labels]
Returns:
bleu: float, approx bleu score
"""
predictions = np.argmax(logits, axis=-1)
bleu = metrics.compute_bleu(labels, predictions)
return bleu
def continuous_eval(strategy,
params,
model_type,
eval_file_pattern=None,
batch_size=4,
eval_steps=None,
model_dir=None,
timeout=3000):
"""Continuously evaluate checkpoints on testing data."""
test_dataset = input_pipeline.get_input_dataset(
eval_file_pattern,
batch_size=batch_size,
params=params,
is_training=False,
strategy=strategy)
with strategy.scope():
model = models.create_model(model_type, params)
metric_layer = metrics_v2.MetricLayer(params.vocab_size)
eval_summary_writer = tf.summary.create_file_writer(
os.path.join(model_dir, "summaries/eval"))
global_step = tf.Variable(
0,
trainable=False,
dtype=tf.int64,
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
shape=[])
model.global_step = global_step
@tf.function
def test_step(inputs):
"""Calculates evaluation metrics on distributed devices."""
def _test_step_fn(inputs):
"""Replicated accuracy calculation."""
targets = models.remove_sos_from_seq(inputs["target_ids"],
params.pad_token_id)
# Using ground truth sequences as targets to calculate logits for accuracy
# and perplexity metrics.
logits, _, _ = model(inputs, training=False, mode="train")
metric_layer([logits, targets])
# Get logits from top beam search results for bleu and rouge metrics.
logits = model(inputs, training=False, mode="eval")
return targets, logits
outputs = strategy.run(_test_step_fn, args=(inputs,))
return tf.nest.map_structure(strategy.experimental_local_results, outputs)
metrics_and_funcs = [
(tf.keras.metrics.Mean("bleu", dtype=tf.float32), bleu_score),
(tf.keras.metrics.Mean("rouge_2_fscore",
dtype=tf.float32), rouge_2_fscore),
(tf.keras.metrics.Mean("rouge_l_fscore",
dtype=tf.float32), rouge_l_fscore),
]
eval_results = {}
for latest_checkpoint in tf.train.checkpoints_iterator(
model_dir, timeout=timeout):
checkpoint = tf.train.Checkpoint(model=model)
checkpoint.restore(latest_checkpoint).expect_partial()
logging.info("Loaded checkpoint %s", latest_checkpoint)
for i, inputs in enumerate(test_dataset):
if eval_steps and i >= eval_steps:
break
outputs = test_step(inputs)
for metric, func in metrics_and_funcs:
for targets, logits in zip(outputs[0], outputs[1]):
metric.update_state(func(logits.numpy(), targets.numpy()))
with eval_summary_writer.as_default():
step = model.global_step.numpy()
for metric, _ in metrics_and_funcs:
eval_results[metric.name] = metric.result().numpy().astype(float)
tf.summary.scalar(
metric.name,
eval_results[metric.name],
step=step)
for metric in metric_layer.metrics:
eval_results[metric.name] = metric.result().numpy().astype(float)
tf.summary.scalar(
metric.name,
eval_results[metric.name],
step=step)
logging.info("Step %d Metrics= %s", step, str(eval_results))
eval_summary_writer.flush()
# Resets metrics.
for metric, _ in metrics_and_funcs:
metric.reset_states()
for metric in metric_layer.metrics:
metric.reset_states()
return eval_results
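# Hedged usage sketch (argument names per the signature above; the paths are
# placeholders):
#
#   strategy = tf.distribute.MirroredStrategy()
#   results = continuous_eval(
#       strategy, params, model_type="nhnet",
#       eval_file_pattern="/path/to/processed/valid.tfrecord*",
#       batch_size=4, eval_steps=100, model_dir="/path/to/output/model")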
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Input pipelines."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v2 as tf
def decode_record(record, name_to_features):
"""Decodes a record to a TensorFlow example."""
example = tf.io.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
for name in list(example.keys()):
t = example[name]
if t.dtype == tf.int64:
t = tf.cast(t, tf.int32)
example[name] = t
return example
def process_singledoc_dataset(dataset, batch_size, params):
"""Parses and batches single-doc dataset."""
name_to_features = {
"input_ids_a": tf.io.FixedLenFeature([params.len_title], tf.int64),
"input_ids_b": tf.io.FixedLenFeature([params.len_passage], tf.int64),
"input_mask_b": tf.io.FixedLenFeature([params.len_passage], tf.int64),
"segment_ids_b": tf.io.FixedLenFeature([params.len_passage], tf.int64),
}
decode_fn = lambda record: decode_record(record, name_to_features)
dataset = dataset.map(
decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
def _select_data_from_record(record):
"""Filter out features to use for pretraining."""
return {
"input_ids": record["input_ids_b"],
"input_mask": record["input_mask_b"],
"segment_ids": record["segment_ids_b"],
"target_ids": record["input_ids_a"],
}
dataset = dataset.map(
_select_data_from_record,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(batch_size, drop_remainder=True)
return dataset
def decode_sparse_record(record, name_to_features):
"""Decodes a sparse record to a TensorFlow example."""
example = tf.io.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
for name in list(example.keys()):
t = example[name]
if t.dtype == tf.int64:
t = tf.cast(t, tf.int32)
example[name] = tf.sparse.to_dense(t)
return example
def _filter_max_length(example, max_title_length=256):
"""Indicates whether the example's length is lower than the maximum length."""
return tf.size(example["targets"]) <= max_title_length
def process_singledoc_transformer_dataset(dataset, batch_size, params):
"""Parses, batches and pads single-doc dataset."""
name_to_features = {
"inputs": tf.io.VarLenFeature(tf.int64),
"targets": tf.io.VarLenFeature(tf.int64),
}
decode_fn = lambda record: decode_sparse_record(record, name_to_features)
dataset = dataset.map(
decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
def _select_data_from_record(record):
"""Filter out features to use for pretraining."""
input_ids = record["inputs"][:params.len_passage]
target_ids = record["targets"]
input_mask = tf.ones_like(input_ids)
segment_ids = tf.zeros_like(input_ids)
return {
"input_ids": input_ids,
"input_mask": input_mask,
"segment_ids": segment_ids,
"target_ids": target_ids,
}
dataset = dataset.filter(lambda x: _filter_max_length(x, params.len_title))
dataset = dataset.map(
_select_data_from_record,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.padded_batch(
batch_size, {
"input_ids": [params.len_passage],
"input_mask": [params.len_passage],
"segment_ids": [params.len_passage],
"target_ids": [params.len_title],
},
padding_values={
"input_ids": params.pad_token_id,
"input_mask": 0,
"segment_ids": 0,
"target_ids": params.pad_token_id,
},
drop_remainder=True)
return dataset
def multidoc_parse_spec(params, training=True):
"""Gets the mutli-doc tf.Example parsing spec."""
len_p = params.len_passage
name_to_features = {}
feature_list = ["input_ids", "input_mask", "segment_ids"]
for idx in params.passage_list:
for feature in feature_list:
name_to_features["%s_%s" % (feature, idx)] = tf.io.FixedLenFeature(
[len_p], tf.int64)
if training:
# Cluster title.
name_to_features["input_ids_a"] = tf.io.FixedLenFeature([params.len_title],
tf.int64)
return name_to_features, feature_list
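# For the default passage_list ["b", "c", "d", "e", "f"] (see configs.py), the
# spec above yields keys such as "input_ids_b", "input_mask_b" and
# "segment_ids_b" for each passage, plus "input_ids_a" for the cluster title
# when training=True.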
def process_multidoc_dataset(dataset, batch_size, params):
"""Parses, organizes and batches multi-doc dataset."""
name_to_features, feature_list = multidoc_parse_spec(params)
decode_fn = lambda record: decode_record(record, name_to_features)
dataset = dataset.map(
decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
def _select_data_from_record(record):
"""Filter out features to use for pretraining."""
features = {"target_ids": record["input_ids_a"]}
for feature in feature_list:
tensors = [record["%s_%s" % (feature, i)] for i in params.passage_list]
features[feature] = tf.stack(tensors)
return features
dataset = dataset.map(
_select_data_from_record,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(batch_size, drop_remainder=True)
return dataset
def create_dataset(file_paths,
batch_size,
params,
is_training=True,
input_pipeline_context=None):
"""Creates input dataset from (tf)records files for pretraining."""
dataset = tf.data.Dataset.list_files(file_paths, shuffle=is_training)
if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
if not is_training or params.input_sharding:
dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
input_pipeline_context.input_pipeline_id)
if is_training:
dataset = dataset.repeat()
# We set shuffle buffer to exactly match total number of
# training files to ensure that training data is well shuffled.
dataset = dataset.shuffle(len(file_paths))
# In parallel, create a TFRecord dataset for each training file.
# cycle_length = 8 means that up to 8 files will be read and deserialized in
# parallel. You may want to increase this number if you have a large number of
# CPU cores.
dataset = dataset.interleave(
tf.data.TFRecordDataset,
cycle_length=8,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
if is_training:
dataset = dataset.shuffle(100)
if params.get("multi_channel_cross_attention", value=False):
dataset = process_multidoc_dataset(dataset, batch_size, params)
else:
if not params.input_data_not_padded:
dataset = process_singledoc_dataset(dataset, batch_size, params)
else:
dataset = process_singledoc_transformer_dataset(dataset, batch_size,
params)
dataset = dataset.prefetch(1024)
return dataset
def get_input_dataset(input_file_pattern,
batch_size,
params,
is_training,
strategy=None):
"""Returns input dataset from input file string."""
# When using TPU pods, we need to clone the dataset across workers and pass in
# a function that returns the dataset, rather than passing the dataset
# instance itself.
use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
if use_dataset_fn:
if batch_size % strategy.num_replicas_in_sync != 0:
raise ValueError(
"Batch size must be divisible by number of replicas : {}".format(
strategy.num_replicas_in_sync))
# As auto rebatching is not supported in
# `experimental_distribute_datasets_from_function()` API, which is
# required when cloning dataset to multiple workers in eager mode,
# we use per-replica batch size.
batch_size = int(batch_size / strategy.num_replicas_in_sync)
def _dataset_fn(ctx=None):
"""Returns tf.data.Dataset for distributed BERT pretraining."""
input_files = []
for input_pattern in input_file_pattern.split(","):
input_files.extend(tf.io.gfile.glob(input_pattern))
return create_dataset(
input_files,
batch_size,
params,
is_training=is_training,
input_pipeline_context=ctx)
if use_dataset_fn:
return strategy.experimental_distribute_datasets_from_function(_dataset_fn)
else:
return strategy.experimental_distribute_dataset(_dataset_fn())
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""tf.keras Models for NHNet."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
from absl import logging
import gin
import tensorflow as tf
from typing import Optional, Text
from official.modeling import tf_utils
from official.modeling.hyperparams import params_dict
from official.nlp.modeling import networks
from official.nlp.nhnet import configs
from official.nlp.nhnet import decoder
from official.nlp.nhnet import multi_channel_attention
from official.nlp.nhnet import utils
from official.nlp.transformer import beam_search
def embedding_linear(embedding_matrix, x):
"""Uses embeddings as linear transformation weights."""
with tf.name_scope("presoftmax_linear"):
batch_size = tf.shape(x)[0]
length = tf.shape(x)[1]
hidden_size = tf.shape(x)[2]
vocab_size = tf.shape(embedding_matrix)[0]
x = tf.reshape(x, [-1, hidden_size])
logits = tf.matmul(x, embedding_matrix, transpose_b=True)
return tf.reshape(logits, [batch_size, length, vocab_size])
def _add_sos_to_seq(seq, start_token_id):
"""Add a start sequence token while keeping seq length."""
batch_size = tf.shape(seq)[0]
seq_len = tf.shape(seq)[1]
sos_ids = tf.ones([batch_size], tf.int32) * start_token_id
targets = tf.concat([tf.expand_dims(sos_ids, axis=1), seq], axis=1)
targets = targets[:, :-1]
tf.assert_equal(tf.shape(targets), (batch_size, seq_len))
return targets
def remove_sos_from_seq(seq, pad_token_id):
"""Remove the start sequence token while keeping seq length."""
batch_size, seq_len = tf_utils.get_shape_list(seq, expected_rank=2)
# remove <s>
targets = seq[:, 1:]
# pad
pad_ids = tf.ones([batch_size], tf.int32) * pad_token_id
targets = tf.concat([targets, tf.expand_dims(pad_ids, axis=1)], axis=1)
tf.assert_equal(tf.shape(targets), (batch_size, seq_len))
return targets
class Bert2Bert(tf.keras.Model):
"""Bert2Bert encoder decoder model for training."""
def __init__(self, params, bert_layer, decoder_layer, name=None):
super(Bert2Bert, self).__init__(name=name)
self.params = params
if not bert_layer.built:
raise ValueError("bert_layer should be built.")
if not decoder_layer.built:
raise ValueError("decoder_layer should be built.")
self.bert_layer = bert_layer
self.decoder_layer = decoder_layer
def get_config(self):
return {"params": self.params.as_dict()}
def get_decode_logits(self,
decoder_inputs,
ids,
decoder_self_attention_bias,
step,
cache=None):
if cache:
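      # Incremental decoding: only the last generated token is fed to the
      # decoder and the self-attention bias is sliced to the current step;
      # keys/values for earlier steps come from the cache.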
if self.params.get("padded_decode", False):
bias_shape = decoder_self_attention_bias.shape.as_list()
self_attention_bias = tf.slice(
decoder_self_attention_bias, [0, 0, step, 0],
[bias_shape[0], bias_shape[1], 1, bias_shape[3]])
else:
self_attention_bias = decoder_self_attention_bias[:, :, step:step +
1, :step + 1]
# Sets decoder input to the last generated IDs.
decoder_input = ids[:, -1:]
else:
self_attention_bias = decoder_self_attention_bias[:, :, :step + 1, :step +
1]
decoder_input = ids
decoder_inputs["target_ids"] = decoder_input
decoder_inputs["self_attention_bias"] = self_attention_bias
if cache:
decoder_outputs = self.decoder_layer(
decoder_inputs,
cache,
decode_loop_step=step,
padded_decode=self.params.get("padded_decode", False))
else:
decoder_outputs = self.decoder_layer(decoder_inputs)
logits = embedding_linear(self.decoder_layer.embedding_lookup.embeddings,
decoder_outputs[:, -1:, :])
logits = tf.squeeze(logits, axis=[1])
return logits
def _get_symbols_to_logits_fn(self, max_decode_length):
"""Returns a decoding function that calculates logits of the next tokens."""
# Max decode length should be smaller than the positional embedding max
# sequence length.
decoder_self_attention_bias = decoder.get_attention_bias(
input_tensor=None,
bias_type="decoder_self",
max_length=max_decode_length)
def _symbols_to_logits_fn(ids, i, cache):
"""Generate logits for next candidate IDs.
Args:
ids: Current decoded sequences. int tensor with shape [batch_size *
beam_size, i + 1]
i: Loop index
cache: dictionary of values storing the encoder output, encoder-decoder
attention bias, and previous decoder attention values.
Returns:
Tuple of
(logits with shape [batch_size * beam_size, vocab_size],
updated cache values)
"""
decoder_inputs = dict(
all_encoder_outputs=cache["all_encoder_outputs"],
attention_bias=cache["attention_bias"])
logits = self.get_decode_logits(
decoder_inputs,
ids,
decoder_self_attention_bias,
step=i,
cache=cache if self.params.use_cache else None)
return logits, cache
return _symbols_to_logits_fn
def train_decode(self, decode_outputs):
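    # Projects decoder outputs to vocabulary logits through the tied
    # embedding matrix, then derives greedy output ids and per-token
    # log-probabilities.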
logits = embedding_linear(self.decoder_layer.embedding_lookup.embeddings,
decode_outputs)
decode_output_ids = tf.cast(tf.argmax(logits, axis=-1), tf.int32)
output_log_probs = tf.nn.log_softmax(logits, axis=-1)
return logits, decode_output_ids, output_log_probs
def predict_decode(self, start_token_ids, cache):
symbols_to_logits_fn = self._get_symbols_to_logits_fn(self.params.len_title)
# Use beam search to find the top beam_size sequences and scores.
decoded_ids, scores = beam_search.sequence_beam_search(
symbols_to_logits_fn=symbols_to_logits_fn,
initial_ids=start_token_ids,
initial_cache=cache,
vocab_size=self.params.vocab_size,
beam_size=self.params.beam_size,
alpha=self.params.alpha,
max_decode_length=self.params.len_title,
padded_decode=self.params.get("padded_decode", False),
eos_id=self.params.end_token_id)
return decoded_ids, scores
def _get_logits_for_decode_ids(self, decoder_inputs, top_decoded_ids):
"""Returns the log probabilities for ids."""
target_ids = _add_sos_to_seq(top_decoded_ids, self.params.start_token_id)
decoder_inputs["self_attention_bias"] = decoder.get_attention_bias(
target_ids, bias_type="decoder_self")
decoder_inputs["target_ids"] = target_ids
decoder_outputs = self.decoder_layer(decoder_inputs)
logits = embedding_linear(self.decoder_layer.embedding_lookup.embeddings,
decoder_outputs)
return logits
def _init_cache(self, batch_size):
num_heads = self.params.num_decoder_attn_heads
dim_per_head = self.params.hidden_size // num_heads
init_decode_length = (
self.params.len_title if self.params.get("padded_decode", False) else 0)
cache = {}
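    # One key/value buffer per decoder layer. With padded decode (used on
    # TPUs), buffers are pre-allocated to the full title length so shapes
    # stay static throughout the decode loop.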
for layer in range(self.params.num_decoder_layers):
cache[str(layer)] = {
"key":
tf.zeros(
[batch_size, init_decode_length, num_heads, dim_per_head],
dtype=tf.float32),
"value":
tf.zeros(
[batch_size, init_decode_length, num_heads, dim_per_head],
dtype=tf.float32)
}
return cache
def call(self, inputs, mode="train"):
"""Implements call().
Args:
inputs: a dictionary of tensors.
      mode: string, one of "train", "eval" or "predict".
    Returns:
      (logits, decode_output_ids, output_log_probs) for training;
      (decoded ids, scores) from beam search for predict; logits over the
      top decoded ids for eval.
"""
input_ids = inputs["input_ids"]
input_mask = inputs["input_mask"]
segment_ids = inputs["segment_ids"]
all_encoder_outputs, _ = self.bert_layer(
[input_ids, input_mask, segment_ids])
if mode not in ("train", "eval", "predict"):
raise ValueError("Invalid call mode: %s" % mode)
encoder_decoder_attention_bias = decoder.get_attention_bias(
input_ids,
bias_type="single_cross",
padding_value=self.params.pad_token_id)
if mode == "train":
self_attention_bias = decoder.get_attention_bias(
inputs["target_ids"], bias_type="decoder_self")
decoder_inputs = dict(
attention_bias=encoder_decoder_attention_bias,
all_encoder_outputs=all_encoder_outputs,
target_ids=inputs["target_ids"],
self_attention_bias=self_attention_bias)
decoder_outputs = self.decoder_layer(decoder_inputs)
return self.train_decode(decoder_outputs)
batch_size = tf.shape(input_ids)[0]
start_token_ids = tf.ones([batch_size],
tf.int32) * self.params.start_token_id
# Add encoder output and attention bias to the cache.
if self.params.use_cache:
cache = self._init_cache(batch_size)
else:
cache = {}
cache["all_encoder_outputs"] = all_encoder_outputs
cache["attention_bias"] = encoder_decoder_attention_bias
decoded_ids, scores = self.predict_decode(start_token_ids, cache)
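    # decoded_ids: [batch_size, beam_size, len_title + 1]; position 0 holds
    # the start token, which is stripped below.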
if mode == "predict":
return decoded_ids[:, :self.params.beam_size,
1:], scores[:, :self.params.beam_size]
decoder_inputs = dict(
attention_bias=encoder_decoder_attention_bias,
all_encoder_outputs=all_encoder_outputs)
top_decoded_ids = decoded_ids[:, 0, 1:]
return self._get_logits_for_decode_ids(decoder_inputs, top_decoded_ids)
class NHNet(Bert2Bert):
"""NHNet model which performs multi-doc decoding."""
def __init__(self, params, bert_layer, decoder_layer, name=None):
super(NHNet, self).__init__(params, bert_layer, decoder_layer, name=name)
self.doc_attention = multi_channel_attention.DocAttention(
num_heads=params.num_decoder_attn_heads,
head_size=params.hidden_size // params.num_decoder_attn_heads)
def _expand_doc_attention_probs(self, doc_attention_probs, target_length):
"""Expands doc attention probs to fit the decoding sequence length."""
    doc_attention_probs = tf.expand_dims(
        doc_attention_probs, axis=1)  # [B, 1, A]
    doc_attention_probs = tf.expand_dims(
        doc_attention_probs, axis=2)  # [B, 1, 1, A]
return tf.tile(doc_attention_probs,
[1, self.params.num_decoder_attn_heads, target_length, 1])
def _get_symbols_to_logits_fn(self, max_decode_length):
"""Returns a decoding function that calculates logits of the next tokens."""
# Max decode length should be smaller than the positional embedding max
# sequence length.
decoder_self_attention_bias = decoder.get_attention_bias(
input_tensor=None,
bias_type="decoder_self",
max_length=max_decode_length)
def _symbols_to_logits_fn(ids, i, cache):
"""Generate logits for next candidate IDs."""
if self.params.use_cache:
target_length = 1
else:
target_length = i + 1
decoder_inputs = dict(
doc_attention_probs=self._expand_doc_attention_probs(
cache["doc_attention_probs"], target_length),
all_encoder_outputs=cache["all_encoder_outputs"],
attention_bias=cache["attention_bias"])
logits = self.get_decode_logits(
decoder_inputs,
ids,
decoder_self_attention_bias,
step=i,
cache=cache if self.params.use_cache else None)
return logits, cache
return _symbols_to_logits_fn
def call(self, inputs, mode="training"):
input_shape = tf_utils.get_shape_list(inputs["input_ids"], expected_rank=3)
batch_size, num_docs, len_passage = (input_shape[0], input_shape[1],
input_shape[2])
input_ids = tf.reshape(inputs["input_ids"], [-1, len_passage])
input_mask = tf.reshape(inputs["input_mask"], [-1, len_passage])
segment_ids = tf.reshape(inputs["segment_ids"], [-1, len_passage])
all_encoder_outputs, _ = self.bert_layer(
[input_ids, input_mask, segment_ids])
encoder_outputs = tf.reshape(
all_encoder_outputs[-1],
[batch_size, num_docs, len_passage, self.params.hidden_size])
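    # A document counts as present only if its input mask has more than two
    # non-zero entries, i.e. it contains tokens beyond the special
    # [CLS]/[SEP] pair.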
doc_attention_mask = tf.reshape(
tf.cast(
tf.math.count_nonzero(input_mask, axis=1, dtype=tf.int32) > 2,
tf.int32), [batch_size, num_docs])
doc_attention_probs = self.doc_attention(encoder_outputs,
doc_attention_mask)
encoder_decoder_attention_bias = decoder.get_attention_bias(
inputs["input_ids"],
bias_type="multi_cross",
padding_value=self.params.pad_token_id)
if mode == "train":
target_length = tf_utils.get_shape_list(
inputs["target_ids"], expected_rank=2)[1]
doc_attention_probs = self._expand_doc_attention_probs(
doc_attention_probs, target_length)
self_attention_bias = decoder.get_attention_bias(
inputs["target_ids"], bias_type="decoder_self")
decoder_inputs = dict(
attention_bias=encoder_decoder_attention_bias,
self_attention_bias=self_attention_bias,
target_ids=inputs["target_ids"],
all_encoder_outputs=encoder_outputs,
doc_attention_probs=doc_attention_probs)
decoder_outputs = self.decoder_layer(decoder_inputs)
return self.train_decode(decoder_outputs)
# Adds encoder output and attention bias to the cache.
if self.params.use_cache:
cache = self._init_cache(batch_size)
else:
cache = {}
cache["all_encoder_outputs"] = [encoder_outputs]
cache["attention_bias"] = encoder_decoder_attention_bias
cache["doc_attention_probs"] = doc_attention_probs
start_token_ids = tf.ones([batch_size],
tf.int32) * self.params.start_token_id
decoded_ids, scores = self.predict_decode(start_token_ids, cache)
if mode == "predict":
return decoded_ids[:, :self.params.beam_size,
1:], scores[:, :self.params.beam_size]
top_decoded_ids = decoded_ids[:, 0, 1:]
target_length = tf_utils.get_shape_list(top_decoded_ids)[-1]
decoder_inputs = dict(
attention_bias=encoder_decoder_attention_bias,
all_encoder_outputs=[encoder_outputs],
doc_attention_probs=self._expand_doc_attention_probs(
doc_attention_probs, target_length))
return self._get_logits_for_decode_ids(decoder_inputs, top_decoded_ids)
def get_bert2bert_layers(params: configs.BERT2BERTConfig):
"""Creates a Bert2Bert stem model and returns Bert encoder/decoder.
  We use a functional-style stem model because we need all layers to be built
  in order to restore variables in a customized way. The layers are called
  with placeholder inputs so that they are fully built.
Args:
params: ParamsDict.
Returns:
two keras Layers, bert_model_layer and decoder_layer
"""
input_ids = tf.keras.layers.Input(
shape=(None,), name="input_ids", dtype=tf.int32)
input_mask = tf.keras.layers.Input(
shape=(None,), name="input_mask", dtype=tf.int32)
segment_ids = tf.keras.layers.Input(
shape=(None,), name="segment_ids", dtype=tf.int32)
target_ids = tf.keras.layers.Input(
shape=(None,), name="target_ids", dtype=tf.int32)
bert_config = utils.get_bert_config_from_params(params)
bert_model_layer = networks.TransformerEncoder(
vocab_size=bert_config.vocab_size,
hidden_size=bert_config.hidden_size,
num_layers=bert_config.num_hidden_layers,
num_attention_heads=bert_config.num_attention_heads,
intermediate_size=bert_config.intermediate_size,
activation=tf_utils.get_activation(bert_config.hidden_act),
dropout_rate=bert_config.hidden_dropout_prob,
attention_dropout_rate=bert_config.attention_probs_dropout_prob,
sequence_length=None,
max_sequence_length=bert_config.max_position_embeddings,
type_vocab_size=bert_config.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range),
return_all_encoder_outputs=True,
name="bert_encoder")
all_encoder_outputs, _ = bert_model_layer(
[input_ids, input_mask, segment_ids])
# pylint: disable=protected-access
decoder_layer = decoder.Decoder(params, bert_model_layer._embedding_layer)
# pylint: enable=protected-access
cross_attention_bias = decoder.AttentionBias(bias_type="single_cross")(
input_ids)
self_attention_bias = decoder.AttentionBias(bias_type="decoder_self")(
target_ids)
decoder_inputs = dict(
attention_bias=cross_attention_bias,
self_attention_bias=self_attention_bias,
target_ids=target_ids,
all_encoder_outputs=all_encoder_outputs)
_ = decoder_layer(decoder_inputs)
return bert_model_layer, decoder_layer
def get_nhnet_layers(params: configs.NHNetConfig):
"""Creates a Mult-doc encoder/decoder.
Args:
params: ParamsDict.
Returns:
two keras Layers, bert_model_layer and decoder_layer
"""
input_ids = tf.keras.layers.Input(
shape=(None,), name="input_ids", dtype=tf.int32)
input_mask = tf.keras.layers.Input(
shape=(None,), name="input_mask", dtype=tf.int32)
segment_ids = tf.keras.layers.Input(
shape=(None,), name="segment_ids", dtype=tf.int32)
bert_config = utils.get_bert_config_from_params(params)
bert_model_layer = networks.TransformerEncoder(
vocab_size=bert_config.vocab_size,
hidden_size=bert_config.hidden_size,
num_layers=bert_config.num_hidden_layers,
num_attention_heads=bert_config.num_attention_heads,
intermediate_size=bert_config.intermediate_size,
activation=tf_utils.get_activation(bert_config.hidden_act),
dropout_rate=bert_config.hidden_dropout_prob,
attention_dropout_rate=bert_config.attention_probs_dropout_prob,
sequence_length=None,
max_sequence_length=bert_config.max_position_embeddings,
type_vocab_size=bert_config.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range),
return_all_encoder_outputs=True,
name="bert_encoder")
bert_model_layer([input_ids, input_mask, segment_ids])
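  # Placeholder inputs are re-created below with shapes carrying a document
  # axis ([batch, num_docs, length]) to build the multi-doc decoder.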
input_ids = tf.keras.layers.Input(
shape=(None, None), name="input_ids", dtype=tf.int32)
all_encoder_outputs = tf.keras.layers.Input((None, None, params.hidden_size),
dtype=tf.float32)
target_ids = tf.keras.layers.Input(
shape=(None,), name="target_ids", dtype=tf.int32)
doc_attention_probs = tf.keras.layers.Input(
(params.num_decoder_attn_heads, None, None), dtype=tf.float32)
# pylint: disable=protected-access
decoder_layer = decoder.Decoder(params, bert_model_layer._embedding_layer)
# pylint: enable=protected-access
cross_attention_bias = decoder.AttentionBias(bias_type="multi_cross")(
input_ids)
self_attention_bias = decoder.AttentionBias(bias_type="decoder_self")(
target_ids)
decoder_inputs = dict(
attention_bias=cross_attention_bias,
self_attention_bias=self_attention_bias,
target_ids=target_ids,
all_encoder_outputs=all_encoder_outputs,
doc_attention_probs=doc_attention_probs)
_ = decoder_layer(decoder_inputs)
return bert_model_layer, decoder_layer
def create_transformer_model(params,
init_checkpoint: Optional[Text] = None
) -> tf.keras.Model:
"""A helper to create Transformer model."""
bert_layer, decoder_layer = get_bert2bert_layers(params=params)
model = Bert2Bert(
params=params,
bert_layer=bert_layer,
decoder_layer=decoder_layer,
name="transformer")
if init_checkpoint:
logging.info(
"Checkpoint file %s found and restoring from "
"initial checkpoint.", init_checkpoint)
ckpt = tf.train.Checkpoint(model=model)
ckpt.restore(init_checkpoint).expect_partial()
return model
def create_bert2bert_model(
params: configs.BERT2BERTConfig,
cls=Bert2Bert,
init_checkpoint: Optional[Text] = None) -> tf.keras.Model:
"""A helper to create Bert2Bert model."""
bert_layer, decoder_layer = get_bert2bert_layers(params=params)
if init_checkpoint:
utils.initialize_bert2bert_from_pretrained_bert(bert_layer, decoder_layer,
init_checkpoint)
return cls(
params=params,
bert_layer=bert_layer,
decoder_layer=decoder_layer,
name="bert2bert")
def create_nhnet_model(
params: configs.NHNetConfig,
cls=NHNet,
init_checkpoint: Optional[Text] = None) -> tf.keras.Model:
"""A helper to create NHNet model."""
bert_layer, decoder_layer = get_nhnet_layers(params=params)
model = cls(
params=params,
bert_layer=bert_layer,
decoder_layer=decoder_layer,
name="nhnet")
if init_checkpoint:
logging.info(
"Checkpoint file %s found and restoring from "
"initial checkpoint.", init_checkpoint)
if params.init_from_bert2bert:
ckpt = tf.train.Checkpoint(model=model)
ckpt.restore(init_checkpoint).assert_existing_objects_matched()
else:
utils.initialize_bert2bert_from_pretrained_bert(bert_layer, decoder_layer,
init_checkpoint)
return model
@gin.configurable
def get_model_params(model: Text = "bert2bert",
                     config_class=None) -> params_dict.ParamsDict:
"""Helper function to convert config file to ParamsDict."""
if model == "bert2bert":
return configs.BERT2BERTConfig()
elif model == "nhnet":
return configs.NHNetConfig()
elif config_class:
return config_class()
else:
raise KeyError("The model type is not defined: %s" % model)
@gin.configurable
def create_model(model_type: Text,
params,
init_checkpoint: Optional[Text] = None):
"""A factory function to create different types of models."""
if model_type == "bert2bert":
return create_bert2bert_model(params, init_checkpoint=init_checkpoint)
elif model_type == "nhnet":
return create_nhnet_model(params, init_checkpoint=init_checkpoint)
elif "transformer" in model_type:
return create_transformer_model(
params, init_checkpoint=init_checkpoint)
else:
raise KeyError("The model type is not defined: %s" % model_type)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Multi-channel decoder."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import math
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling import layers
class DocAttention(tf.keras.layers.Layer):
"""Documents Attention layer."""
def __init__(self,
num_heads,
head_size,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs):
super(DocAttention, self).__init__(**kwargs)
self._num_heads = num_heads
self._head_size = head_size
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
self._bias_initializer = tf.keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
    # Stores the activity regularizer, which was previously accepted but
    # silently dropped (build() reads self._activity_regularizer).
    self._activity_regularizer = tf.keras.regularizers.get(
        activity_regularizer)
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
def build(self, unused_input_shapes):
self._query_dense = layers.DenseEinsum(
output_shape=(self._num_heads, self._head_size),
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
dtype=self.dtype,
name="encdocatt_query")
self._key_dense = layers.DenseEinsum(
output_shape=(self._num_heads, self._head_size),
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
dtype=self.dtype,
name="encdocatt_key")
super(DocAttention, self).build(unused_input_shapes)
def call(self, encoder_outputs, doc_attention_mask):
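    # Shapes: B = batch (stories), A = num_docs, N = num_heads, H = head_size.
    # Each document is represented by its leading ([CLS]) token embedding.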
num_docs = tf_utils.get_shape_list(encoder_outputs, expected_rank=[4])[1]
cls_embeddings = encoder_outputs[:, :, 0, :]
key = self._key_dense(cls_embeddings)
query = self._query_dense(cls_embeddings)
doc_attention_mask = tf.cast(doc_attention_mask, tf.float32)
key = tf.einsum("BANH,BA->BANH", key, doc_attention_mask)
query = tf.einsum("BANH,BA->BANH", query, doc_attention_mask)
attention_matrix = tf.einsum("BXNH,BYNH->BNXY", query, key)
mask = tf.ones([num_docs, num_docs])
mask = tf.linalg.set_diag(mask, tf.zeros(num_docs))
attention_matrix = tf.einsum("BNXY,XY->BNXY", attention_matrix, mask)
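    # The diagonal was zeroed above so a document never attends to itself;
    # summing over target docs (Y) and then heads (N) leaves one raw score
    # per document.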
doc_attention_probs = tf.einsum("BNAY->BNA", attention_matrix)
doc_attention_probs = tf.einsum("BNA->BA", doc_attention_probs)
infadder = (1.0 - doc_attention_mask) * -100000.0
return tf.nn.softmax(doc_attention_probs + infadder)
class MultiChannelAttention(layers.MultiHeadAttention):
"""Multi-channel Attention layer."""
def __init__(self, num_heads, head_size, **kwargs):
super(MultiChannelAttention, self).__init__(num_heads, head_size, **kwargs)
self._masked_softmax = layers.MaskedSoftmax(mask_expansion_axes=[2])
def compute_output_shape(self, input_shape):
if len(input_shape) != 4:
raise ValueError("Layer %s must have 4 input tensors." % self.name)
from_tensor_shape = tf.TensorShape(input_shape[0])
batch = from_tensor_shape[0]
from_tensor_length = from_tensor_shape[1]
return tf.TensorShape(
(batch, from_tensor_length, self._num_heads, self._head_size))
def call(self, inputs):
from_tensor = inputs[0]
to_tensor = inputs[1]
attention_mask = inputs[2]
doc_attention_probs = inputs[3]
# Scalar dimensions referenced here:
# B = batch size (number of stories)
# A = num_docs (number of docs)
# F = `from_tensor` sequence length
# T = `to_tensor` sequence length
# N = `num_attention_heads`
# H = `size_per_head`
# `query_tensor` = [B, F, N ,H]
query_tensor = self._query_dense(from_tensor)
# `key_tensor` = [B, A, T, N, H]
key_tensor = self._key_dense(to_tensor)
# `value_tensor` = [B, A, T, N, H]
value_tensor = self._value_dense(to_tensor)
# Take the dot product between "query" and "key" to get the raw
# attention scores.
attention_scores = tf.einsum("BATNH,BFNH->BANFT", key_tensor, query_tensor)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._head_size)))
# Normalize the attention scores to probabilities.
# `attention_probs` = [B, A, N, F, T]
attention_probs = self._masked_softmax([attention_scores, attention_mask])
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self._dropout(attention_probs)
    # `context_layer` = [B, A, F, N, H]
context_layer = tf.einsum("BANFT,BATNH->BAFNH", attention_probs,
value_tensor)
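    # Weights each document's context by its doc-level attention probability
    # and sums over documents: [B, N, F, A] x [B, A, F, N, H] -> [B, F, N, H].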
return tf.einsum("BNFA,BAFNH->BFNH", doc_attention_probs, context_layer)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for nlp.nhnet.multi_channel_attention."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from official.nlp.nhnet import multi_channel_attention
class MultiChannelAttentionTest(tf.test.TestCase):
def test_doc_attention(self):
num_heads = 2
doc_attention = multi_channel_attention.DocAttention(num_heads, head_size=8)
num_docs = 3
inputs = np.zeros((2, num_docs, 10, 16), dtype=np.float32)
doc_mask = np.zeros((2, num_docs), dtype=np.float32)
outputs = doc_attention(inputs, doc_mask)
self.assertEqual(outputs.shape, (2, num_docs))
def test_multi_channel_attention(self):
num_heads = 2
num_docs = 5
attention_layer = multi_channel_attention.MultiChannelAttention(
num_heads, head_size=2)
from_data = 10 * np.random.random_sample((3, 4, 8))
to_data = 10 * np.random.random_sample((3, num_docs, 2, 8))
mask_data = np.random.randint(2, size=(3, num_docs, 4, 2))
doc_probs = np.random.randint(
2, size=(3, num_heads, 4, num_docs)).astype(float)
outputs = attention_layer([from_data, to_data, mask_data, doc_probs])
self.assertEqual(outputs.shape, (3, 4, num_heads, 2))
if __name__ == "__main__":
tf.test.main()
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Optimizer and learning rate scheduler."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import tensorflow as tf
from official.modeling.hyperparams import params_dict
class LearningRateSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Learning rate schedule."""
def __init__(self, initial_learning_rate, hidden_size, warmup_steps):
"""Initialize configuration of the learning rate schedule.
Args:
initial_learning_rate: A float, the initial learning rate.
hidden_size: An integer, the model dimension in the hidden layers.
warmup_steps: An integer, the number of steps required for linear warmup.
"""
super(LearningRateSchedule, self).__init__()
self.initial_learning_rate = initial_learning_rate
self.hidden_size = hidden_size
self.warmup_steps = tf.cast(warmup_steps, tf.float32)
def __call__(self, global_step):
"""Calculate learning rate with linear warmup and rsqrt decay.
Args:
global_step: An integer, the current global step used for learning rate
calculation.
Returns:
      A float, the learning rate to be used for the current global step.
"""
with tf.name_scope('learning_rate_schedule'):
global_step = tf.cast(global_step, tf.float32)
learning_rate = self.initial_learning_rate
learning_rate *= (self.hidden_size**-0.5)
# Apply linear warmup
learning_rate *= tf.minimum(1.0, global_step / self.warmup_steps)
# Apply rsqrt decay
learning_rate /= tf.sqrt(tf.maximum(global_step, self.warmup_steps))
return learning_rate
def get_config(self):
"""Get the configuration of the learning rate schedule."""
return {
'initial_learning_rate': self.initial_learning_rate,
'hidden_size': self.hidden_size,
'warmup_steps': self.warmup_steps,
}
def create_optimizer(params: params_dict.ParamsDict):
"""Creates optimizer."""
lr_schedule = LearningRateSchedule(
params.learning_rate,
params.hidden_size,
params.learning_rate_warmup_steps)
return tf.keras.optimizers.Adam(
learning_rate=lr_schedule,
beta_1=params.adam_beta1,
beta_2=params.adam_beta2,
epsilon=params.adam_epsilon)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Processes crawled content from news URLs by generating tfrecords."""
import os
from absl import app
from absl import flags
from official.nlp.nhnet import raw_data_processor
FLAGS = flags.FLAGS
flags.DEFINE_string("crawled_articles", "/tmp/nhnet/",
"Folder path to the crawled articles using news-please.")
flags.DEFINE_string("vocab", None, "Filepath of the BERT vocabulary.")
flags.DEFINE_bool("do_lower_case", True,
"Whether the vocabulary is uncased or not.")
flags.DEFINE_integer("len_title", 15,
"Maximum number of tokens in story headline.")
flags.DEFINE_integer("len_passage", 200,
"Maximum number of tokens in article passage.")
flags.DEFINE_integer("max_num_articles", 5,
"Maximum number of articles in a story.")
flags.DEFINE_bool("include_article_title_in_passage", False,
"Whether to include article title in article passage.")
flags.DEFINE_string("data_folder", None,
"Folder path to the downloaded data folder (output).")
flags.DEFINE_integer("num_tfrecords_shards", 20,
"Number of shards for train/valid/test.")
def transform_as_tfrecords(data_processor, filename):
"""Transforms story from json to tfrecord (sharded).
Args:
data_processor: Instance of RawDataProcessor.
filename: 'train', 'valid', or 'test'.
"""
print("Transforming json to tfrecord for %s..." % filename)
story_filepath = os.path.join(FLAGS.data_folder, filename + ".json")
output_folder = os.path.join(FLAGS.data_folder, "processed")
os.makedirs(output_folder, exist_ok=True)
output_filepaths = []
for i in range(FLAGS.num_tfrecords_shards):
output_filepaths.append(
os.path.join(
output_folder, "%s.tfrecord-%.5d-of-%.5d" %
(filename, i, FLAGS.num_tfrecords_shards)))
(total_num_examples,
generated_num_examples) = data_processor.generate_examples(
story_filepath, output_filepaths)
print("For %s, %d examples have been generated from %d stories in json." %
(filename, generated_num_examples, total_num_examples))
def main(_):
if not FLAGS.data_folder:
raise ValueError("data_folder must be set as the downloaded folder path.")
if not FLAGS.vocab:
raise ValueError("vocab must be set as the filepath of BERT vocabulary.")
data_processor = raw_data_processor.RawDataProcessor(
vocab=FLAGS.vocab,
do_lower_case=FLAGS.do_lower_case,
len_title=FLAGS.len_title,
len_passage=FLAGS.len_passage,
max_num_articles=FLAGS.max_num_articles,
include_article_title_in_passage=FLAGS.include_article_title_in_passage,
include_text_snippet_in_example=True)
print("Loading crawled articles...")
num_articles = data_processor.read_crawled_articles(FLAGS.crawled_articles)
print("Total number of articles loaded: %d" % num_articles)
print()
transform_as_tfrecords(data_processor, "train")
transform_as_tfrecords(data_processor, "valid")
transform_as_tfrecords(data_processor, "test")
if __name__ == "__main__":
app.run(main)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library for processing crawled content and generating tfrecords."""
import collections
import json
import multiprocessing
import os
import urllib.parse
import tensorflow as tf
from official.nlp.bert import tokenization
from official.nlp.data import classifier_data_lib
class RawDataProcessor(object):
"""Data converter for story examples."""
def __init__(self,
vocab: str,
do_lower_case: bool,
len_title: int = 15,
len_passage: int = 200,
max_num_articles: int = 5,
include_article_title_in_passage: bool = False,
include_text_snippet_in_example: bool = False):
"""Constructs a RawDataProcessor.
Args:
vocab: Filepath of the BERT vocabulary.
do_lower_case: Whether the vocabulary is uncased or not.
len_title: Maximum number of tokens in story headline.
len_passage: Maximum number of tokens in article passage.
max_num_articles: Maximum number of articles in a story.
include_article_title_in_passage: Whether to include article title in
article passage.
include_text_snippet_in_example: Whether to include text snippet
(headline and article content) in generated tensorflow Examples, for
debug usage. If include_article_title_in_passage=True, title and body
will be separated by [SEP].
"""
self.articles = dict()
self.tokenizer = tokenization.FullTokenizer(
vocab, do_lower_case=do_lower_case, split_on_punc=False)
self.len_title = len_title
self.len_passage = len_passage
self.max_num_articles = max_num_articles
self.include_article_title_in_passage = include_article_title_in_passage
self.include_text_snippet_in_example = include_text_snippet_in_example
# ex_index=5 deactivates printing inside convert_single_example.
self.ex_index = 5
# Parameters used in InputExample, not used in NHNet.
self.label = 0
self.guid = 0
self.num_generated_examples = 0
def read_crawled_articles(self, folder_path):
"""Reads crawled articles under folder_path."""
for path, _, files in os.walk(folder_path):
for name in files:
if not name.endswith(".json"):
continue
url, article = self._get_article_content_from_json(
os.path.join(path, name))
if not article.text_a:
continue
self.articles[RawDataProcessor.normalize_url(url)] = article
if len(self.articles) % 5000 == 0:
print("Number of articles loaded: %d\r" % len(self.articles), end="")
print()
return len(self.articles)
def generate_examples(self, input_file, output_files):
"""Loads story from input json file and exports examples in output_files."""
    # Resets the per-call counter so repeated calls (train/valid/test) report
    # and return per-file counts rather than a running total.
    self.num_generated_examples = 0
    writers = []
story_partition = []
for output_file in output_files:
writers.append(tf.io.TFRecordWriter(output_file))
story_partition.append(list())
with tf.io.gfile.GFile(input_file, "r") as story_json_file:
stories = json.load(story_json_file)
writer_index = 0
for story in stories:
articles = []
for url in story["urls"]:
normalized_url = RawDataProcessor.normalize_url(url)
if normalized_url in self.articles:
articles.append(self.articles[normalized_url])
if not articles:
continue
story_partition[writer_index].append((story["label"], articles))
writer_index = (writer_index + 1) % len(writers)
lock = multiprocessing.Lock()
pool = multiprocessing.pool.ThreadPool(len(writers))
data = [(story_partition[i], writers[i], lock) for i in range(len(writers))]
    pool.map(self._write_story_partition, data)
    # Closes writers explicitly so buffered records are flushed to disk.
    for writer in writers:
      writer.close()
    return len(stories), self.num_generated_examples
@classmethod
def normalize_url(cls, url):
"""Normalize url for better matching."""
url = urllib.parse.unquote(
urllib.parse.urlsplit(url)._replace(query=None).geturl())
    output = []
    for part in url.split("//"):
      if part not in ("http:", "https:"):
        output.append(part)
    return "//".join(output)
def _get_article_content_from_json(self, file_path):
"""Returns (url, InputExample) keeping content extracted from file_path."""
with tf.io.gfile.GFile(file_path, "r") as article_json_file:
article = json.load(article_json_file)
if self.include_article_title_in_passage:
return article["url"], classifier_data_lib.InputExample(
guid=self.guid,
text_a=article["title"],
text_b=article["maintext"],
label=self.label)
else:
return article["url"], classifier_data_lib.InputExample(
guid=self.guid, text_a=article["maintext"], label=self.label)
def _write_story_partition(self, data):
"""Writes stories in a partition into file."""
for (story_headline, articles) in data[0]:
story_example = tf.train.Example(
features=tf.train.Features(
feature=self._get_single_story_features(story_headline,
articles)))
data[1].write(story_example.SerializeToString())
data[2].acquire()
try:
self.num_generated_examples += 1
if self.num_generated_examples % 1000 == 0:
print(
"Number of stories written: %d\r" % self.num_generated_examples,
end="")
finally:
data[2].release()
def _get_single_story_features(self, story_headline, articles):
"""Converts a list of articles to a tensorflow Example."""
def get_text_snippet(article):
if article.text_b:
return " [SEP] ".join([article.text_a, article.text_b])
else:
return article.text_a
story_features = collections.OrderedDict()
story_headline_feature = classifier_data_lib.convert_single_example(
ex_index=self.ex_index,
example=classifier_data_lib.InputExample(
guid=self.guid, text_a=story_headline, label=self.label),
label_list=[self.label],
max_seq_length=self.len_title,
tokenizer=self.tokenizer)
if self.include_text_snippet_in_example:
story_headline_feature.label_id = story_headline
self._add_feature_with_suffix(
feature=story_headline_feature,
suffix="a",
story_features=story_features)
for (article_index, article) in enumerate(articles):
if article_index == self.max_num_articles:
break
article_feature = classifier_data_lib.convert_single_example(
ex_index=self.ex_index,
example=article,
label_list=[self.label],
max_seq_length=self.len_passage,
tokenizer=self.tokenizer)
if self.include_text_snippet_in_example:
article_feature.label_id = get_text_snippet(article)
suffix = chr(ord("b") + article_index)
self._add_feature_with_suffix(
feature=article_feature, suffix=suffix, story_features=story_features)
# Adds empty features as placeholder.
for article_index in range(len(articles), self.max_num_articles):
suffix = chr(ord("b") + article_index)
empty_article = classifier_data_lib.InputExample(
guid=self.guid, text_a="", label=self.label)
empty_feature = classifier_data_lib.convert_single_example(
ex_index=self.ex_index,
example=empty_article,
label_list=[self.label],
max_seq_length=self.len_passage,
tokenizer=self.tokenizer)
if self.include_text_snippet_in_example:
empty_feature.label_id = ""
self._add_feature_with_suffix(
feature=empty_feature, suffix=suffix, story_features=story_features)
return story_features
def _add_feature_with_suffix(self, feature, suffix, story_features):
"""Appends suffix to feature names and fills in the corresponding values."""
def _create_int_feature(values):
return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
def _create_string_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
story_features["input_ids_%c" % suffix] = _create_int_feature(
feature.input_ids)
story_features["input_mask_%c" % suffix] = _create_int_feature(
feature.input_mask)
story_features["segment_ids_%c" % suffix] = _create_int_feature(
feature.segment_ids)
if self.include_text_snippet_in_example:
story_features["text_snippet_%c" % suffix] = _create_string_feature(
bytes(feature.label_id.encode()))
<!DOCTYPE html>
<meta charset="utf-8">
<title>Page Title 0</title>
{
"title": "title for 0",
"maintext": "text snippet for 0",
"url": "http://url_000.html"
}
<!DOCTYPE html>
<meta charset="utf-8">
<title>Page Title 1</title>
{
"title": "title for 1",
"maintext": "text snippet for 1",
"url": "url_001.html"
}