Commit 44e7092c authored by stephenwu

Merge branch 'master' of https://github.com/tensorflow/models into AXg

parents 431a9ca3 59434199
......@@ -63,6 +63,8 @@ def main(_):
params=params,
model_dir=model_dir)
train_utils.save_gin_config(FLAGS.mode, model_dir)
if __name__ == '__main__':
tfm_flags.define_flags()
app.run(main)
......@@ -267,7 +267,8 @@ class ProgressiveTrainer(trainer_lib.Trainer):
step_interval = self.config.trainer.checkpoint_interval
else:
step_interval = self.config.trainer.export_checkpoint_interval
if global_step_np % step_interval != 0:
if global_step_np % step_interval != 0 and (
global_step_np < self._config.trainer.train_steps):
logging.info('Not exporting checkpoints in global step: %d.',
global_step_np)
return
......
......@@ -136,7 +136,7 @@ class BigBirdEncoderConfig(hyperparams.Config):
block_size: int = 64
type_vocab_size: int = 16
initializer_range: float = 0.02
embedding_size: Optional[int] = None
embedding_width: Optional[int] = None
@dataclasses.dataclass
......@@ -290,11 +290,11 @@ def build_encoder(config: EncoderConfig,
attention_dropout_rate=encoder_cfg.attention_dropout_rate,
num_rand_blocks=encoder_cfg.num_rand_blocks,
block_size=encoder_cfg.block_size,
max_sequence_length=encoder_cfg.max_position_embeddings,
max_position_embeddings=encoder_cfg.max_position_embeddings,
type_vocab_size=encoder_cfg.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
embedding_width=encoder_cfg.embedding_size)
embedding_width=encoder_cfg.embedding_width)
if encoder_type == "xlnet":
return encoder_cls(
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFM continuous finetuning+eval training driver library."""
import gc
import os
import time
from typing import Any, Mapping, Optional
from absl import logging
import tensorflow as tf
from official.common import distribute_utils
from official.core import config_definitions
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
from official.modeling.multitask import configs
from official.modeling.multitask import multitask
from official.modeling.multitask import train_lib as multitask_train_lib
def _flatten_dict(xs):
"""Flatten a nested dictionary.
The nested keys are flattened to a tuple.
Example::
xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
flat_xs = flatten_dict(xs)
print(flat_xs)
# {
# ('foo',): 1,
# ('bar', 'a'): 2,
# }
Note that empty dictionaries are ignored and
will not be restored by `unflatten_dict`.
Args:
xs: a nested dictionary
Returns:
The flattened dictionary.
"""
assert isinstance(xs, dict), 'input is not a dict'
def _flatten(xs, prefix):
if not isinstance(xs, dict):
return {prefix: xs}
result = {}
for key, value in xs.items():
path = prefix + (key,)
result.update(_flatten(value, path))
return result
return _flatten(xs, ())
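# A small illustrative sketch (the helper name below is hypothetical and only
# added for clarity): nested metric dicts flatten to tuple keys, and empty
# sub-dicts are dropped.
def _flatten_dict_example():
  metrics = {'masked_lm': {'accuracy': 0.71, 'loss': 1.9}, 'extra': {}}
  flat = _flatten_dict(metrics)
  assert flat == {('masked_lm', 'accuracy'): 0.71, ('masked_lm', 'loss'): 1.9}
  return flat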
def run_continuous_finetune(
mode: str,
params: config_definitions.ExperimentConfig,
model_dir: str,
run_post_eval: bool = False,
pretrain_steps: Optional[int] = None,
) -> Mapping[str, Any]:
"""Run modes with continuous training.
Currently only supports continuous_train_and_eval.
Args:
mode: A 'str', specifying the mode. continuous_train_and_eval - monitors a
checkpoint directory. Once a new checkpoint is discovered, loads the
checkpoint, finetunes the model by training it (probably on another dataset
or with another task), then evaluates the finetuned model.
params: ExperimentConfig instance.
model_dir: A 'str', a path to store model checkpoints and summaries.
run_post_eval: Whether to run one evaluation after training; if True, the
eval metrics logs are returned.
pretrain_steps: Optional, the number of total training steps for the
pretraining job.
Returns:
eval logs: returns eval metrics logs when run_post_eval is set to True,
otherwise, returns {}.
"""
assert mode == 'continuous_train_and_eval', (
'Only continuous_train_and_eval is supported by continuous_finetune. '
'Got mode: {}'.format(mode))
# Sets the mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
# can have a significant impact on model speed by utilizing float16 on GPUs
# and bfloat16 on TPUs. loss_scale takes effect only when the dtype is
# float16.
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu)
retry_times = 0
while not tf.io.gfile.isdir(params.task.init_checkpoint):
# Wait for the init_checkpoint directory to be created.
if retry_times >= 60:
raise ValueError(
'ExperimentConfig.task.init_checkpoint must be a directory for '
'continuous_train_and_eval mode.')
retry_times += 1
time.sleep(60)
summary_writer = tf.summary.create_file_writer(
os.path.join(model_dir, 'eval'))
global_step = 0
def timeout_fn():
if pretrain_steps and global_step < pretrain_steps:
# Keeps waiting for another timeout period.
logging.info(
'Continue waiting for new checkpoint as current pretrain '
'global_step=%d and target is %d.', global_step, pretrain_steps)
return False
# Quits the loop.
return True
for pretrain_ckpt in tf.train.checkpoints_iterator(
checkpoint_dir=params.task.init_checkpoint,
min_interval_secs=10,
timeout=params.trainer.continuous_eval_timeout,
timeout_fn=timeout_fn):
with distribution_strategy.scope():
global_step = train_utils.read_global_step_from_checkpoint(pretrain_ckpt)
# Replaces params.task.init_checkpoint to make sure that we load
# exactly this pretrain checkpoint.
if params.trainer.best_checkpoint_export_subdir:
best_ckpt_subdir = '{}_{}'.format(
params.trainer.best_checkpoint_export_subdir, global_step)
params_replaced = params.replace(
task={'init_checkpoint': pretrain_ckpt},
trainer={'best_checkpoint_export_subdir': best_ckpt_subdir})
else:
params_replaced = params.replace(task={'init_checkpoint': pretrain_ckpt})
params_replaced.lock()
logging.info('Running finetuning with params: %s', params_replaced)
with distribution_strategy.scope():
if isinstance(params, configs.MultiEvalExperimentConfig):
task = task_factory.get_task(params_replaced.task)
eval_tasks = multitask.MultiTask.from_config(params_replaced.eval_tasks)
(_,
eval_metrics) = multitask_train_lib.run_experiment_with_multitask_eval(
distribution_strategy=distribution_strategy,
train_task=task,
eval_tasks=eval_tasks,
mode='train_and_eval',
params=params_replaced,
model_dir=model_dir,
run_post_eval=True,
save_summary=False)
else:
task = task_factory.get_task(
params_replaced.task, logging_dir=model_dir)
_, eval_metrics = train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode='train_and_eval',
params=params_replaced,
model_dir=model_dir,
run_post_eval=True,
save_summary=False)
logging.info('Evaluation finished. Pretrain global_step: %d', global_step)
train_utils.write_json_summary(model_dir, global_step, eval_metrics)
if not os.path.basename(model_dir): # if model_dir.endswith('/')
summary_grp = os.path.dirname(model_dir) + '_' + task.name
else:
summary_grp = os.path.basename(model_dir) + '_' + task.name
summaries = {}
for name, value in _flatten_dict(eval_metrics).items():
summaries[summary_grp + '/' + '-'.join(name)] = value
train_utils.write_summary(summary_writer, global_step, summaries)
train_utils.remove_ckpts(model_dir)
# In TF2, the resource life cycle is bound to the Python object life
# cycle. Force-trigger Python garbage collection here so those resources
# are deallocated in time and do not cause OOM when allocating new
# objects.
# TODO(b/169178664): Fix cycle reference in Keras model and revisit to see
# if we need gc here.
gc.collect()
if run_post_eval:
return eval_metrics
return {}
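# A hedged driver sketch: the helper below is hypothetical (not part of this
# change) and simply mirrors how the unit test further down invokes the API;
# pretrain_steps=1000 is a placeholder value.
def _run_continuous_finetune_example(flags_obj):
  params = train_utils.parse_configuration(flags_obj)
  return run_continuous_finetune(
      mode='continuous_train_and_eval',
      params=params,
      model_dir=flags_obj.model_dir,
      run_post_eval=True,
      pretrain_steps=1000)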
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -19,11 +18,15 @@ from absl import flags
from absl.testing import flagsaver
from absl.testing import parameterized
import tensorflow as tf
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.nlp import train_ctl_continuous_finetune
from official.nlp import continuous_finetune_lib
FLAGS = flags.FLAGS
......@@ -36,8 +39,8 @@ class ContinuousFinetuneTest(tf.test.TestCase, parameterized.TestCase):
super().setUp()
self._model_dir = os.path.join(self.get_temp_dir(), 'model_dir')
@parameterized.parameters(None, 1)
def testTrainCtl(self, pretrain_steps):
def testContinuousFinetune(self):
pretrain_steps = 1
src_model_dir = self.get_temp_dir()
flags_dict = dict(
experiment='mock',
......@@ -79,7 +82,7 @@ class ContinuousFinetuneTest(tf.test.TestCase, parameterized.TestCase):
model_dir=src_model_dir)
params = train_utils.parse_configuration(FLAGS)
eval_metrics = train_ctl_continuous_finetune.run_continuous_finetune(
eval_metrics = continuous_finetune_lib.run_continuous_finetune(
FLAGS.mode,
params,
FLAGS.model_dir,
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Script to compute official BLEU score.
Source:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/bleu_hook.py
"""
import collections
import math
import re
import sys
import unicodedata
import numpy as np
import tensorflow as tf
class UnicodeRegex(object):
"""Ad-hoc hack to recognize all punctuation and symbols."""
def __init__(self):
punctuation = self.property_chars("P")
self.nondigit_punct_re = re.compile(r"([^\d])([" + punctuation + r"])")
self.punct_nondigit_re = re.compile(r"([" + punctuation + r"])([^\d])")
self.symbol_re = re.compile("([" + self.property_chars("S") + "])")
def property_chars(self, prefix):
return "".join(
chr(x)
for x in range(sys.maxunicode)
if unicodedata.category(chr(x)).startswith(prefix))
uregex = UnicodeRegex()
def bleu_tokenize(string):
r"""Tokenize a string following the official BLEU implementation.
See https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983
In our case, the input string is expected to be just one line
and no de-escaping of HTML entities is needed.
So we just tokenize on punctuation and symbols,
except when a punctuation mark is preceded and followed by a digit
(e.g. a comma/dot as a thousand/decimal separator).
Note that a number (e.g. a year) followed by a dot at the end of a sentence
is NOT tokenized,
i.e. the dot stays with the number because `s/(\p{P})(\P{N})/ $1 $2/g`
does not match this case (unless we add a space after each sentence).
However, this error is already in the original mteval-v14.pl
and we want to be consistent with it.
Args:
string: the input string
Returns:
a list of tokens
"""
string = uregex.nondigit_punct_re.sub(r"\1 \2 ", string)
string = uregex.punct_nondigit_re.sub(r" \1 \2", string)
string = uregex.symbol_re.sub(r" \1 ", string)
return string.split()
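# A brief, hedged illustration of the rules documented above (the example
# function is added for clarity only): punctuation between digits stays
# attached to the token, while other punctuation is split off.
def _bleu_tokenize_example():
  tokens = bleu_tokenize("It costs 1,000.50 dollars.")
  assert tokens == ["It", "costs", "1,000.50", "dollars", "."]
  return tokens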
def bleu_wrapper(ref_filename, hyp_filename, case_sensitive=False):
"""Compute BLEU for two files (reference and hypothesis translation)."""
ref_lines = tf.io.gfile.GFile(ref_filename).read().strip().splitlines()
hyp_lines = tf.io.gfile.GFile(hyp_filename).read().strip().splitlines()
return bleu_on_list(ref_lines, hyp_lines, case_sensitive)
def _get_ngrams_with_counter(segment, max_order):
"""Extracts all n-grams up to a given maximum order from an input segment.
Args:
segment: text segment from which n-grams will be extracted.
max_order: maximum length in tokens of the n-grams returned by this
method.
Returns:
The Counter containing all n-grams up to max_order in segment
with a count of how many times each n-gram occurred.
"""
ngram_counts = collections.Counter()
for order in range(1, max_order + 1):
for i in range(0, len(segment) - order + 1):
ngram = tuple(segment[i:i + order])
ngram_counts[ngram] += 1
return ngram_counts
def compute_bleu(reference_corpus, translation_corpus, max_order=4,
use_bp=True):
"""Computes BLEU score of translated segments against one or more references.
Args:
reference_corpus: list of references for each translation. Each
reference should be tokenized into a list of tokens.
translation_corpus: list of translations to score. Each translation
should be tokenized into a list of tokens.
max_order: Maximum n-gram order to use when computing BLEU score.
use_bp: boolean, whether to apply brevity penalty.
Returns:
BLEU score.
"""
reference_length = 0
translation_length = 0
bp = 1.0
geo_mean = 0
matches_by_order = [0] * max_order
possible_matches_by_order = [0] * max_order
precisions = []
for (references, translations) in zip(reference_corpus, translation_corpus):
reference_length += len(references)
translation_length += len(translations)
ref_ngram_counts = _get_ngrams_with_counter(references, max_order)
translation_ngram_counts = _get_ngrams_with_counter(translations, max_order)
overlap = dict((ngram,
min(count, translation_ngram_counts[ngram]))
for ngram, count in ref_ngram_counts.items())
for ngram in overlap:
matches_by_order[len(ngram) - 1] += overlap[ngram]
for ngram in translation_ngram_counts:
possible_matches_by_order[len(ngram) - 1] += translation_ngram_counts[
ngram]
precisions = [0] * max_order
smooth = 1.0
for i in range(0, max_order):
if possible_matches_by_order[i] > 0:
precisions[i] = float(matches_by_order[i]) / possible_matches_by_order[i]
if matches_by_order[i] > 0:
precisions[i] = float(matches_by_order[i]) / possible_matches_by_order[
i]
else:
smooth *= 2
precisions[i] = 1.0 / (smooth * possible_matches_by_order[i])
else:
precisions[i] = 0.0
if max(precisions) > 0:
p_log_sum = sum(math.log(p) for p in precisions if p)
geo_mean = math.exp(p_log_sum / max_order)
if use_bp:
ratio = translation_length / reference_length
bp = math.exp(1 - 1. / ratio) if ratio < 1.0 else 1.0
bleu = geo_mean * bp
return np.float32(bleu)
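# A small sanity-check sketch (illustrative only): compute_bleu expects
# pre-tokenized corpora, and an identical reference/translation pair scores
# 1.0 (bleu_on_list below scales this by 100).
def _compute_bleu_example():
  refs = [["the", "cat", "sat", "on", "the", "mat"]]
  hyps = [["the", "cat", "sat", "on", "the", "mat"]]
  assert compute_bleu(refs, hyps) == 1.0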
def bleu_on_list(ref_lines, hyp_lines, case_sensitive=False):
"""Compute BLEU for two list of strings (reference and hypothesis)."""
if len(ref_lines) != len(hyp_lines):
raise ValueError(
"Reference and translation files have different number of "
"lines (%d VS %d). If training only a few steps (100-200), the "
"translation may be empty." % (len(ref_lines), len(hyp_lines)))
if not case_sensitive:
ref_lines = [x.lower() for x in ref_lines]
hyp_lines = [x.lower() for x in hyp_lines]
ref_tokens = [bleu_tokenize(x) for x in ref_lines]
hyp_tokens = [bleu_tokenize(x) for x in hyp_lines]
return compute_bleu(ref_tokens, hyp_tokens) * 100
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test functions in compute_blue.py."""
import tempfile
import tensorflow as tf
from official.nlp.metrics import bleu
class ComputeBleuTest(tf.test.TestCase):
def _create_temp_file(self, text):
temp_file = tempfile.NamedTemporaryFile(delete=False)
with tf.io.gfile.GFile(temp_file.name, "w") as w:
w.write(text)
return temp_file.name
def test_bleu_same(self):
ref = self._create_temp_file("test 1 two 3\nmore tests!")
hyp = self._create_temp_file("test 1 two 3\nmore tests!")
uncased_score = bleu.bleu_wrapper(ref, hyp, False)
cased_score = bleu.bleu_wrapper(ref, hyp, True)
self.assertEqual(100, uncased_score)
self.assertEqual(100, cased_score)
def test_bleu_same_different_case(self):
ref = self._create_temp_file("Test 1 two 3\nmore tests!")
hyp = self._create_temp_file("test 1 two 3\nMore tests!")
uncased_score = bleu.bleu_wrapper(ref, hyp, False)
cased_score = bleu.bleu_wrapper(ref, hyp, True)
self.assertEqual(100, uncased_score)
self.assertLess(cased_score, 100)
def test_bleu_different(self):
ref = self._create_temp_file("Testing\nmore tests!")
hyp = self._create_temp_file("Dog\nCat")
uncased_score = bleu.bleu_wrapper(ref, hyp, False)
cased_score = bleu.bleu_wrapper(ref, hyp, True)
self.assertLess(uncased_score, 100)
self.assertLess(cased_score, 100)
def test_bleu_tokenize(self):
s = "Test0, 1 two, 3"
tokenized = bleu.bleu_tokenize(s)
self.assertEqual(["Test0", ",", "1", "two", ",", "3"], tokenized)
def test_bleu_list(self):
ref = ["test 1 two 3", "more tests!"]
hyp = ["test 1 two 3", "More tests!"]
uncased_score = bleu.bleu_on_list(ref, hyp, False)
cased_score = bleu.bleu_on_list(ref, hyp, True)
self.assertEqual(uncased_score, 100)
self.assertLess(cased_score, 100)
if __name__ == "__main__":
tf.test.main()
......@@ -26,6 +26,7 @@ from official.nlp.modeling.layers.mobile_bert_layers import MobileBertMaskedLM
from official.nlp.modeling.layers.mobile_bert_layers import MobileBertTransformer
from official.nlp.modeling.layers.multi_channel_attention import *
from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
from official.nlp.modeling.layers.position_embedding import RelativePositionBias
from official.nlp.modeling.layers.position_embedding import RelativePositionEmbedding
from official.nlp.modeling.layers.relative_attention import MultiHeadRelativeAttention
from official.nlp.modeling.layers.relative_attention import TwoStreamRelativeAttention
......
......@@ -14,13 +14,15 @@
# ==============================================================================
"""Keras-based positional embedding layer."""
# pylint: disable=g-classes-have-attributes
import math
from typing import Optional
import tensorflow as tf
from official.modeling import tf_utils
Initializer = tf.keras.initializers.Initializer
@tf.keras.utils.register_keras_serializable(package="Text")
class RelativePositionEmbedding(tf.keras.layers.Layer):
......@@ -38,9 +40,9 @@ class RelativePositionEmbedding(tf.keras.layers.Layer):
"""
def __init__(self,
hidden_size,
min_timescale=1.0,
max_timescale=1.0e4,
hidden_size: int,
min_timescale: float = 1.0,
max_timescale: float = 1.0e4,
**kwargs):
# We need to have a default dtype of float32, since the inputs (which Keras
# usually uses to infer the dtype) will always be int32.
......@@ -50,7 +52,7 @@ class RelativePositionEmbedding(tf.keras.layers.Layer):
if "dtype" not in kwargs:
kwargs["dtype"] = "float32"
super(RelativePositionEmbedding, self).__init__(**kwargs)
super().__init__(**kwargs)
self._hidden_size = hidden_size
self._min_timescale = min_timescale
self._max_timescale = max_timescale
......@@ -101,3 +103,135 @@ class RelativePositionEmbedding(tf.keras.layers.Layer):
[tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
return position_embeddings
def _relative_position_bucket(relative_position,
bidirectional=True,
num_buckets=32,
max_distance=128):
"""Translate relative position to a bucket number for relative attention.
The relative position is defined as memory_position - query_position, i.e.
the distance in tokens from the attending position to the attended-to
position.
If bidirectional=False, then positive relative positions are invalid.
We use smaller buckets for small absolute relative_position and larger
buckets for larger absolute relative_positions.
All relative positions >=max_distance map to the same bucket.
All relative positions <=-max_distance map to the same bucket.
This should allow for more graceful generalization to longer sequences
than the model has been trained on.
Args:
relative_position: an int32 Tensor
bidirectional: a boolean - whether the attention is bidirectional
num_buckets: an integer
max_distance: an integer
Returns:
a Tensor with the same shape as relative_position, containing int32
values in the range [0, num_buckets)
"""
ret = 0
n = -relative_position
if bidirectional:
num_buckets //= 2
ret += tf.cast(tf.math.less(n, 0), tf.int32) * num_buckets
n = tf.math.abs(n)
else:
n = tf.math.maximum(n, 0)
# now n is in the range [0, inf)
max_exact = num_buckets // 2
is_small = tf.math.less(n, max_exact)
val_if_large = max_exact + tf.dtypes.cast(
tf.math.log(tf.cast(n, tf.float32) / max_exact) /
math.log(max_distance / max_exact) * (num_buckets - max_exact),
tf.int32,
)
val_if_large = tf.math.minimum(val_if_large, num_buckets - 1)
ret += tf.where(is_small, n, val_if_large)
return ret
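# A short illustration (the expected values mirror the unit test further
# below): with the defaults, buckets for positive relative positions occupy
# the upper half of the range and negative ones the lower half.
def _relative_position_bucket_example():
  relative_position = tf.constant([[0, 1], [-1, 0], [-2, -1]])
  buckets = _relative_position_bucket(relative_position)
  # Expected buckets: [[0, 17], [1, 0], [2, 1]].
  return buckets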
@tf.keras.utils.register_keras_serializable(package="Text")
class RelativePositionBias(tf.keras.layers.Layer):
"""Relative position embedding via per-head bias in T5 style.
Reference implementation in MeshTF:
https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L1000
This layer implements the relative position bias used in "Exploring the Limits
of Transfer Learning with a Unified Text-to-Text Transformer"
(https://arxiv.org/abs/1910.10683)
"""
def __init__(self,
num_heads: int,
relative_attention_num_buckets: int = 32,
relative_attention_max_distance: int = 128,
bidirectional: bool = True,
embeddings_initializer: Optional[Initializer] = None,
**kwargs):
super().__init__(**kwargs)
self.num_heads = num_heads
self.relative_attention_num_buckets = relative_attention_num_buckets
self.bidirectional = bidirectional
self.relative_attention_max_distance = relative_attention_max_distance
if embeddings_initializer:
self._embed_init = embeddings_initializer
else:
self._embed_init = tf.keras.initializers.TruncatedNormal(stddev=1.0)
with tf.name_scope(self.name):
self._relative_attention_bias = self.add_weight(
"rel_embedding",
shape=[self.relative_attention_num_buckets, self.num_heads],
initializer=self._embed_init,
dtype=self.dtype,
trainable=True)
def get_config(self):
config = {
"num_heads":
self.num_heads,
"relative_attention_num_buckets":
self.relative_attention_num_buckets,
"relative_attention_max_distance":
self.relative_attention_max_distance,
"bidirectional":
self.bidirectional,
"embeddings_initializer":
tf.keras.initializers.serialize(self._embed_init),
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, query: tf.Tensor, key: tf.Tensor):
"""Implements the forward pass.
Args:
query: query input tensor of shape [batch, query length, hidden size].
key: key input tensor of shape [batch, key length, hidden size].
Returns:
A tensor of shape [batch, heads, query length, key length].
"""
batch_size, qlen = tf_utils.get_shape_list(query)[:2]
klen = tf_utils.get_shape_list(key)[1]
context_position = tf.range(qlen)[:, None]
memory_position = tf.range(klen)[None, :]
relative_position = memory_position - context_position
rp_bucket = _relative_position_bucket(
relative_position,
bidirectional=self.bidirectional,
num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance)
values = tf.nn.embedding_lookup(self._relative_attention_bias, rp_bucket)
values = tf.expand_dims(
tf.transpose(values, [2, 0, 1]),
axis=0) # shape (1, num_heads, qlen, klen)
values = tf.tile(values, [batch_size, 1, 1, 1])
return values
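# A minimal usage sketch (shapes follow the unit test added alongside this
# layer): the output is a per-head bias of shape
# [batch, num_heads, query length, key length] that can be added to
# attention logits.
def _relative_position_bias_example():
  bias_layer = RelativePositionBias(num_heads=3)
  bias = bias_layer(tf.zeros((4, 4, 2)), tf.zeros((4, 2, 2)))
  assert bias.shape == (4, 3, 4, 2)
  return bias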
......@@ -14,6 +14,8 @@
# ==============================================================================
"""Tests for Keras-based positional embedding layer."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
......@@ -55,5 +57,32 @@ class RelativePositionEmbeddingLayerTest(keras_parameterized.TestCase):
self.assertAllEqual(output_tensor, expected_output_tensor)
@keras_parameterized.run_all_keras_modes
class RelativePositionBiasTest(keras_parameterized.TestCase):
@parameterized.named_parameters(("bidirectional", True),
("unidirectional", False))
def test_relative_position_bias(self, bidirectional):
query = tf.zeros((4, 4, 2))
key = tf.zeros((4, 2, 2))
l = position_embedding.RelativePositionBias(
num_heads=3,
bidirectional=bidirectional,
name="foo")
self.assertEqual(l(query, key).shape, (4, 3, 4, 2))
self.assertLen(l.trainable_variables, 1)
self.assertEqual(l.trainable_variables[0].name, "foo/rel_embedding:0")
def test_relative_position_bucket(self):
context_position = tf.range(3)[:, None]
memory_position = tf.range(2)[None, :]
relative_position = memory_position - context_position
outputs = position_embedding._relative_position_bucket(relative_position)
self.assertAllEqual(outputs.numpy(), np.array([[0, 17], [1, 0], [2, 1]]))
outputs = position_embedding._relative_position_bucket(
relative_position, bidirectional=False)
self.assertAllEqual(outputs.numpy(), np.array([[0, 0], [1, 0], [2, 1]]))
if __name__ == "__main__":
tf.test.main()
......@@ -133,7 +133,16 @@ class BertTokenizer(tf.keras.layers.Layer):
_check_if_tf_text_installed()
self.tokenize_with_offsets = tokenize_with_offsets
self._vocab_table = self._create_vocab_table(vocab_file)
# TODO(b/177326279): Stop storing the vocab table initializer as an
# attribute when https://github.com/tensorflow/tensorflow/issues/46456
# has been fixed in the TensorFlow versions of the TF Hub users that load
# a SavedModel created from this layer. Due to that issue, loading such a
# SavedModel forgets to add .vocab_table._initializer as a trackable
# dependency of .vocab_table, so that saving it again to a second SavedModel
# (e.g., the final model built using TF Hub) does not properly track
# the ._vocab_table._initializer._filename as an Asset.
self._vocab_table, self._vocab_initializer_donotuse = (
self._create_vocab_table_and_initializer(vocab_file))
self._special_tokens_dict = self._create_special_tokens_dict(
self._vocab_table, vocab_file)
super().__init__(**kwargs)
......@@ -144,12 +153,13 @@ class BertTokenizer(tf.keras.layers.Layer):
def vocab_size(self):
return self._vocab_table.size()
def _create_vocab_table(self, vocab_file):
def _create_vocab_table_and_initializer(self, vocab_file):
vocab_initializer = tf.lookup.TextFileInitializer(
vocab_file,
key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
return tf.lookup.StaticHashTable(vocab_initializer, default_value=-1)
vocab_table = tf.lookup.StaticHashTable(vocab_initializer, default_value=-1)
return vocab_table, vocab_initializer
def call(self, inputs: tf.Tensor):
"""Calls text.BertTokenizer on inputs.
......@@ -211,6 +221,7 @@ class BertTokenizer(tf.keras.layers.Layer):
* end_of_segment_id: looked up from "[SEP]"
* padding_id: looked up from "[PAD]"
* mask_id: looked up from "[MASK]"
* vocab_size: one past the largest token id used
"""
return self._special_tokens_dict
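# A hedged usage sketch (the helper is hypothetical; key names follow the
# docstring above, including the newly added vocab_size entry):
def _special_tokens_example(tokenizer):
  tokens_dict = tokenizer.get_special_tokens_dict()
  pad_id = tokens_dict["padding_id"]
  vocab_size = tokens_dict["vocab_size"]  # One past the largest token id.
  return pad_id, vocab_size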
......@@ -223,6 +234,7 @@ class BertTokenizer(tf.keras.layers.Layer):
if tf.executing_eagerly():
special_token_ids = vocab_table.lookup(
tf.constant(list(special_tokens.values()), tf.string))
vocab_size = vocab_table.size()
else:
# A blast from the past: non-eager init context while building Model.
# This can happen with Estimator or tf.compat.v1.disable_v2_behavior().
......@@ -230,16 +242,21 @@ class BertTokenizer(tf.keras.layers.Layer):
"Non-eager init context; computing "
"BertTokenizer's special_tokens_dict in tf.compat.v1.Session")
with tf.Graph().as_default():
local_vocab_table = self._create_vocab_table(vocab_file)
local_vocab_table, _ = self._create_vocab_table_and_initializer(
vocab_file)
special_token_ids_tensor = local_vocab_table.lookup(
tf.constant(list(special_tokens.values()), tf.string))
vocab_size_tensor = local_vocab_table.size()
init_ops = [tf.compat.v1.initialize_all_tables()]
with tf.compat.v1.Session() as sess:
sess.run(init_ops)
special_token_ids = sess.run(special_token_ids_tensor)
result = dict()
special_token_ids, vocab_size = sess.run(
[special_token_ids_tensor, vocab_size_tensor])
result = dict(
vocab_size=int(vocab_size) # Numpy to Python.
)
for k, v in zip(special_tokens, special_token_ids):
v = int(v) # Numpy to Python.
v = int(v)
if v >= 0:
result[k] = v
else:
......@@ -414,6 +431,7 @@ class SentencepieceTokenizer(tf.keras.layers.Layer):
* end_of_segment_id: looked up from "[SEP]"
* padding_id: looked up from "<pad>"
* mask_id: looked up from "[MASK]"
* vocab_size: one past the largest token id used
"""
return self._special_tokens_dict
......@@ -428,6 +446,7 @@ class SentencepieceTokenizer(tf.keras.layers.Layer):
special_token_ids = self._tokenizer.string_to_id(
tf.constant(list(special_tokens.values()), tf.string))
inverse_tokens = self._tokenizer.id_to_string(special_token_ids)
vocab_size = self._tokenizer.vocab_size()
else:
# A blast from the past: non-eager init context while building Model.
# This can happen with Estimator or tf.compat.v1.disable_v2_behavior().
......@@ -440,15 +459,19 @@ class SentencepieceTokenizer(tf.keras.layers.Layer):
tf.constant(list(special_tokens.values()), tf.string))
inverse_tokens_tensor = local_tokenizer.id_to_string(
special_token_ids_tensor)
vocab_size_tensor = local_tokenizer.vocab_size()
with tf.compat.v1.Session() as sess:
special_token_ids, inverse_tokens = sess.run(
[special_token_ids_tensor, inverse_tokens_tensor])
result = dict()
special_token_ids, inverse_tokens, vocab_size = sess.run(
[special_token_ids_tensor, inverse_tokens_tensor,
vocab_size_tensor])
result = dict(
vocab_size=int(vocab_size) # Numpy to Python.
)
for name, token_id, inverse_token in zip(special_tokens,
special_token_ids,
inverse_tokens):
if special_tokens[name] == inverse_token:
result[name] = int(token_id) # Numpy to Python.
result[name] = int(token_id)
else:
logging.warning(
"Could not find %s as token \"%s\" in sentencepiece model, "
......
......@@ -130,7 +130,8 @@ class BertTokenizerTest(tf.test.TestCase):
dict(padding_id=1,
start_of_sequence_id=3,
end_of_segment_id=4,
mask_id=5))
mask_id=5,
vocab_size=7))
def test_special_tokens_partial(self):
vocab_file = self._make_vocab_file(
......@@ -140,7 +141,8 @@ class BertTokenizerTest(tf.test.TestCase):
self.assertDictEqual(bert_tokenize.get_special_tokens_dict(),
dict(padding_id=0,
start_of_sequence_id=1,
end_of_segment_id=2)) # No mask_id,
end_of_segment_id=2,
vocab_size=3)) # No mask_id,
def test_special_tokens_in_estimator(self):
"""Tests getting special tokens without an Eager init context."""
......@@ -252,7 +254,8 @@ class SentencepieceTokenizerTest(tf.test.TestCase):
dict(padding_id=0,
start_of_sequence_id=2,
end_of_segment_id=3,
mask_id=4))
mask_id=4,
vocab_size=16))
def test_special_tokens_in_estimator(self):
"""Tests getting special tokens without an Eager init context."""
......
......@@ -160,8 +160,8 @@ class Seq2SeqTransformer(tf.keras.Model):
embedded_inputs = self.embedding_lookup(sources)
embedding_mask = tf.cast(
tf.not_equal(sources, 0), self.embedding_lookup.embeddings.dtype)
embedded_inputs *= tf.expand_dims(embedding_mask, -1)
embedded_inputs = tf.cast(embedded_inputs, self._dtype)
embedded_inputs *= tf.expand_dims(embedding_mask, -1)
# Attention_mask generation.
input_shape = tf_utils.get_shape_list(sources, expected_rank=2)
attention_mask = tf.cast(
......@@ -243,8 +243,8 @@ class Seq2SeqTransformer(tf.keras.Model):
decoder_inputs = self.embedding_lookup(targets)
embedding_mask = tf.cast(
tf.not_equal(targets, 0), self.embedding_lookup.embeddings.dtype)
decoder_inputs *= tf.expand_dims(embedding_mask, -1)
decoder_inputs = tf.cast(decoder_inputs, self._dtype)
decoder_inputs *= tf.expand_dims(embedding_mask, -1)
# Shift targets to the right, and remove the last element
decoder_inputs = tf.pad(decoder_inputs, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
length = tf.shape(decoder_inputs)[1]
......
......@@ -82,7 +82,7 @@ Next, we can run the following data preprocess script which may take a few hours
```shell
# Recall that we use DATA_FOLDER=/path/to/downloaded_dataset.
$ python3 raw_data_preprocess.py \
$ python3 raw_data_process.py \
-crawled_articles=/tmp/nhnet \
-vocab=/path/to/bert_checkpoint/vocab.txt \
-do_lower_case=True \
......
......@@ -36,9 +36,10 @@ class BigBirdEncoder(tf.keras.Model):
num_layers: The number of transformer layers.
num_attention_heads: The number of attention heads for each transformer. The
hidden size must be divisible by the number of attention heads.
max_sequence_length: The maximum sequence length that this encoder can
consume. If None, max_sequence_length uses the value from sequence length.
This determines the variable shape for positional embeddings.
max_position_embeddings: The maximum length of position embeddings that this
encoder can consume. If None, max_position_embeddings uses the value from
sequence length. This determines the variable shape for positional
embeddings.
type_vocab_size: The number of types that the 'type_ids' input can take.
intermediate_size: The intermediate size for the transformer layers.
activation: The activation to use for the transformer layers.
......@@ -58,7 +59,7 @@ class BigBirdEncoder(tf.keras.Model):
hidden_size=768,
num_layers=12,
num_attention_heads=12,
max_sequence_length=attention.MAX_SEQ_LEN,
max_position_embeddings=attention.MAX_SEQ_LEN,
type_vocab_size=16,
intermediate_size=3072,
block_size=64,
......@@ -78,7 +79,7 @@ class BigBirdEncoder(tf.keras.Model):
'hidden_size': hidden_size,
'num_layers': num_layers,
'num_attention_heads': num_attention_heads,
'max_sequence_length': max_sequence_length,
'max_position_embeddings': max_position_embeddings,
'type_vocab_size': type_vocab_size,
'intermediate_size': intermediate_size,
'block_size': block_size,
......@@ -109,7 +110,7 @@ class BigBirdEncoder(tf.keras.Model):
# Always uses dynamic slicing for simplicity.
self._position_embedding_layer = keras_nlp.layers.PositionEmbedding(
initializer=initializer,
max_length=max_sequence_length,
max_length=max_position_embeddings,
name='position_embedding')
position_embeddings = self._position_embedding_layer(word_embeddings)
self._type_embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
......@@ -159,7 +160,7 @@ class BigBirdEncoder(tf.keras.Model):
from_block_size=block_size,
to_block_size=block_size,
num_rand_blocks=num_rand_blocks,
max_rand_mask_length=max_sequence_length,
max_rand_mask_length=max_position_embeddings,
seed=i),
dropout_rate=dropout_rate,
attention_dropout_rate=dropout_rate,
......
......@@ -27,7 +27,7 @@ class BigBirdEncoderTest(tf.test.TestCase):
batch_size = 2
vocab_size = 1024
network = encoder.BigBirdEncoder(
num_layers=1, vocab_size=1024, max_sequence_length=4096)
num_layers=1, vocab_size=1024, max_position_embeddings=4096)
word_id_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length))
mask_data = np.random.randint(2, size=(batch_size, sequence_length))
......@@ -41,7 +41,7 @@ class BigBirdEncoderTest(tf.test.TestCase):
batch_size = 2
vocab_size = 1024
network = encoder.BigBirdEncoder(
num_layers=1, vocab_size=1024, max_sequence_length=4096)
num_layers=1, vocab_size=1024, max_position_embeddings=4096)
word_id_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length))
mask_data = np.random.randint(2, size=(batch_size, sequence_length))
......
......@@ -146,7 +146,7 @@ def read_model_config(encoder,
return encoder_config
@gin.configurable(blacklist=[
@gin.configurable(denylist=[
'model',
'strategy',
'train_dataset',
......
......@@ -28,8 +28,8 @@ from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling.hyperparams import base_config
from official.nlp.data import data_loader_factory
from official.nlp.metrics import bleu
from official.nlp.modeling import models
from official.nlp.transformer import compute_bleu
def _pad_tensors_to_same_length(x, y):
......@@ -364,6 +364,6 @@ class TranslationTask(base_task.Task):
src, translation, self._references[u_id])
sacrebleu_score = sacrebleu.corpus_bleu(
translations, [self._references]).score
bleu_score = compute_bleu.bleu_on_list(self._references, translations)
bleu_score = bleu.bleu_on_list(self._references, translations)
return {"sacrebleu_score": sacrebleu_score,
"bleu_score": bleu_score}
......@@ -64,6 +64,8 @@ def main(_):
params=params,
model_dir=model_dir)
train_utils.save_gin_config(FLAGS.mode, model_dir)
if __name__ == '__main__':
tfm_flags.define_flags()
app.run(main)