Commit 44e7092c authored by stephenwu's avatar stephenwu

Merge branch 'master' of https://github.com/tensorflow/models into AXg

parents 431a9ca3 59434199
@@ -63,6 +63,8 @@ def main(_):
 params=params,
 model_dir=model_dir)
+train_utils.save_gin_config(FLAGS.mode, model_dir)
 if __name__ == '__main__':
 tfm_flags.define_flags()
 app.run(main)
@@ -267,7 +267,8 @@ class ProgressiveTrainer(trainer_lib.Trainer):
 step_interval = self.config.trainer.checkpoint_interval
 else:
 step_interval = self.config.trainer.export_checkpoint_interval
-if global_step_np % step_interval != 0:
+if global_step_np % step_interval != 0 and (
+    global_step_np < self._config.trainer.train_steps):
 logging.info('Not exporting checkpoints in global step: %d.',
 global_step_np)
 return
...
@@ -136,7 +136,7 @@ class BigBirdEncoderConfig(hyperparams.Config):
 block_size: int = 64
 type_vocab_size: int = 16
 initializer_range: float = 0.02
-embedding_size: Optional[int] = None
+embedding_width: Optional[int] = None
 @dataclasses.dataclass
@@ -290,11 +290,11 @@ def build_encoder(config: EncoderConfig,
 attention_dropout_rate=encoder_cfg.attention_dropout_rate,
 num_rand_blocks=encoder_cfg.num_rand_blocks,
 block_size=encoder_cfg.block_size,
-max_sequence_length=encoder_cfg.max_position_embeddings,
+max_position_embeddings=encoder_cfg.max_position_embeddings,
 type_vocab_size=encoder_cfg.type_vocab_size,
 initializer=tf.keras.initializers.TruncatedNormal(
 stddev=encoder_cfg.initializer_range),
-embedding_width=encoder_cfg.embedding_size)
+embedding_width=encoder_cfg.embedding_width)
 if encoder_type == "xlnet":
 return encoder_cls(
...
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFM continuous finetuning+eval training driver library."""
import gc
import os
import time
from typing import Any, Mapping, Optional
from absl import logging
import tensorflow as tf
from official.common import distribute_utils
from official.core import config_definitions
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
from official.modeling.multitask import configs
from official.modeling.multitask import multitask
from official.modeling.multitask import train_lib as multitask_train_lib
def _flatten_dict(xs):
"""Flatten a nested dictionary.
The nested keys are flattened to a tuple.
Example::
xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
flat_xs = _flatten_dict(xs)
print(flat_xs)
# {
# ('foo',): 1,
# ('bar', 'a'): 2,
# }
Note that empty dictionaries are ignored and
will not be restored by `unflatten_dict`.
Args:
xs: a nested dictionary
Returns:
The flattened dictionary.
"""
assert isinstance(xs, dict), 'input is not a dict'
def _flatten(xs, prefix):
if not isinstance(xs, dict):
return {prefix: xs}
result = {}
for key, value in xs.items():
path = prefix + (key,)
result.update(_flatten(value, path))
return result
return _flatten(xs, ())
def run_continuous_finetune(
mode: str,
params: config_definitions.ExperimentConfig,
model_dir: str,
run_post_eval: bool = False,
pretrain_steps: Optional[int] = None,
) -> Mapping[str, Any]:
"""Run modes with continuous training.
Currently only supports continuous_train_and_eval.
Args:
mode: A 'str', specifying the mode. continuous_train_and_eval - monitors a
checkpoint directory. Once a new checkpoint is discovered, loads the
checkpoint, finetunes the model by training it (probably on another dataset
or with another task), then evaluates the finetuned model.
params: ExperimentConfig instance.
model_dir: A 'str', a path to store model checkpoints and summaries.
run_post_eval: Whether to run post eval once after training, metrics logs
are returned.
pretrain_steps: Optional, the number of total training steps for the
pretraining job.
Returns:
eval logs: returns eval metrics logs when run_post_eval is set to True,
otherwise, returns {}.
"""
assert mode == 'continuous_train_and_eval', (
'Only continuous_train_and_eval is supported by continuous_finetune. '
'Got mode: {}'.format(mode))
# Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
# can have a significant impact on model speed by utilizing float16 in the
# case of GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only
# when dtype is float16.
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu)
retry_times = 0
while not tf.io.gfile.isdir(params.task.init_checkpoint):
# Wait for the init_checkpoint directory to be created.
if retry_times >= 60:
raise ValueError(
'ExperimentConfig.task.init_checkpoint must be a directory for '
'continuous_train_and_eval mode.')
retry_times += 1
time.sleep(60)
summary_writer = tf.summary.create_file_writer(
os.path.join(model_dir, 'eval'))
global_step = 0
def timeout_fn():
if pretrain_steps and global_step < pretrain_steps:
# Keeps waiting for another timeout period.
logging.info(
'Continue waiting for new checkpoint as current pretrain '
'global_step=%d and target is %d.', global_step, pretrain_steps)
return False
# Quits the loop.
return True
for pretrain_ckpt in tf.train.checkpoints_iterator(
checkpoint_dir=params.task.init_checkpoint,
min_interval_secs=10,
timeout=params.trainer.continuous_eval_timeout,
timeout_fn=timeout_fn):
with distribution_strategy.scope():
global_step = train_utils.read_global_step_from_checkpoint(pretrain_ckpt)
# Replaces params.task.init_checkpoint to make sure that we load
# exactly this pretrain checkpoint.
if params.trainer.best_checkpoint_export_subdir:
best_ckpt_subdir = '{}_{}'.format(
params.trainer.best_checkpoint_export_subdir, global_step)
params_replaced = params.replace(
task={'init_checkpoint': pretrain_ckpt},
trainer={'best_checkpoint_export_subdir': best_ckpt_subdir})
else:
params_replaced = params.replace(task={'init_checkpoint': pretrain_ckpt})
params_replaced.lock()
logging.info('Running finetuning with params: %s', params_replaced)
with distribution_strategy.scope():
if isinstance(params, configs.MultiEvalExperimentConfig):
task = task_factory.get_task(params_replaced.task)
eval_tasks = multitask.MultiTask.from_config(params_replaced.eval_tasks)
(_,
eval_metrics) = multitask_train_lib.run_experiment_with_multitask_eval(
distribution_strategy=distribution_strategy,
train_task=task,
eval_tasks=eval_tasks,
mode='train_and_eval',
params=params_replaced,
model_dir=model_dir,
run_post_eval=True,
save_summary=False)
else:
task = task_factory.get_task(
params_replaced.task, logging_dir=model_dir)
_, eval_metrics = train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode='train_and_eval',
params=params_replaced,
model_dir=model_dir,
run_post_eval=True,
save_summary=False)
logging.info('Evaluation finished. Pretrain global_step: %d', global_step)
train_utils.write_json_summary(model_dir, global_step, eval_metrics)
if not os.path.basename(model_dir): # if model_dir.endswith('/')
summary_grp = os.path.dirname(model_dir) + '_' + task.name
else:
summary_grp = os.path.basename(model_dir) + '_' + task.name
summaries = {}
for name, value in _flatten_dict(eval_metrics).items():
summaries[summary_grp + '/' + '-'.join(name)] = value
train_utils.write_summary(summary_writer, global_step, summaries)
train_utils.remove_ckpts(model_dir)
# In TF2, the resource life cycle is bound with the python object life
# cycle. Force trigger python garbage collection here so those resources
# can be deallocated in time, so it doesn't cause OOM when allocating new
# objects.
# TODO(b/169178664): Fix cycle reference in Keras model and revisit to see
# if we need gc here.
gc.collect()
if run_post_eval:
return eval_metrics
return {}
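For orientation, a minimal sketch of how this library could be wired into a driver binary is shown below. It mirrors the test and train.py changes in this commit; the flag plumbing shown here is an assumption, while the imports and calls come from the APIs in this commit.

# Hypothetical driver sketch, not part of this commit. Assumes the usual TFM
# flags (--experiment, --config_file, --mode, --model_dir, ...) are passed.
from absl import app
from absl import flags

from official.common import flags as tfm_flags
from official.core import train_utils
from official.nlp import continuous_finetune_lib

FLAGS = flags.FLAGS


def main(_):
  # Builds an ExperimentConfig from --experiment/--config_file/--params_override.
  params = train_utils.parse_configuration(FLAGS)
  # With --mode=continuous_train_and_eval, this watches task.init_checkpoint
  # for new pretraining checkpoints, finetunes on each one, and evaluates.
  eval_metrics = continuous_finetune_lib.run_continuous_finetune(
      FLAGS.mode, params, FLAGS.model_dir, run_post_eval=True)
  print(eval_metrics)


if __name__ == '__main__':
  tfm_flags.define_flags()
  app.run(main)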
-# Lint as: python3
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,11 +18,15 @@ from absl import flags
 from absl.testing import flagsaver
 from absl.testing import parameterized
 import tensorflow as tf
+# pylint: disable=unused-import
+from official.common import registry_imports
+# pylint: enable=unused-import
 from official.common import flags as tfm_flags
 from official.core import task_factory
 from official.core import train_lib
 from official.core import train_utils
-from official.nlp import train_ctl_continuous_finetune
+from official.nlp import continuous_finetune_lib
 FLAGS = flags.FLAGS
@@ -36,8 +39,8 @@ class ContinuousFinetuneTest(tf.test.TestCase, parameterized.TestCase):
 super().setUp()
 self._model_dir = os.path.join(self.get_temp_dir(), 'model_dir')
-@parameterized.parameters(None, 1)
-def testTrainCtl(self, pretrain_steps):
+def testContinuousFinetune(self):
+  pretrain_steps = 1
 src_model_dir = self.get_temp_dir()
 flags_dict = dict(
 experiment='mock',
@@ -79,7 +82,7 @@ class ContinuousFinetuneTest(tf.test.TestCase, parameterized.TestCase):
 model_dir=src_model_dir)
 params = train_utils.parse_configuration(FLAGS)
-eval_metrics = train_ctl_continuous_finetune.run_continuous_finetune(
+eval_metrics = continuous_finetune_lib.run_continuous_finetune(
 FLAGS.mode,
 params,
 FLAGS.model_dir,
...
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Script to compute official BLEU score.
Source:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/bleu_hook.py
"""
import collections
import math
import re
import sys
import unicodedata
import numpy as np
import tensorflow as tf
class UnicodeRegex(object):
"""Ad-hoc hack to recognize all punctuation and symbols."""
def __init__(self):
punctuation = self.property_chars("P")
self.nondigit_punct_re = re.compile(r"([^\d])([" + punctuation + r"])")
self.punct_nondigit_re = re.compile(r"([" + punctuation + r"])([^\d])")
self.symbol_re = re.compile("([" + self.property_chars("S") + "])")
def property_chars(self, prefix):
return "".join(
chr(x)
for x in range(sys.maxunicode)
if unicodedata.category(chr(x)).startswith(prefix))
uregex = UnicodeRegex()
def bleu_tokenize(string):
r"""Tokenize a string following the official BLEU implementation.
See https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983
In our case, the input string is expected to be just one line
and no HTML entities de-escaping is needed.
So we just tokenize on punctuation and symbols,
except when a punctuation is preceded and followed by a digit
(e.g. a comma/dot as a thousand/decimal separator).
Note that a number (e.g. a year) followed by a dot at the end of sentence
is NOT tokenized,
i.e. the dot stays with the number because `s/(\p{P})(\P{N})/ $1 $2/g`
does not match this case (unless we add a space after each sentence).
However, this error is already in the original mteval-v14.pl
and we want to be consistent with it.
Args:
string: the input string
Returns:
a list of tokens
"""
string = uregex.nondigit_punct_re.sub(r"\1 \2 ", string)
string = uregex.punct_nondigit_re.sub(r" \1 \2", string)
string = uregex.symbol_re.sub(r" \1 ", string)
return string.split()
def bleu_wrapper(ref_filename, hyp_filename, case_sensitive=False):
"""Compute BLEU for two files (reference and hypothesis translation)."""
ref_lines = tf.io.gfile.GFile(ref_filename).read().strip().splitlines()
hyp_lines = tf.io.gfile.GFile(hyp_filename).read().strip().splitlines()
return bleu_on_list(ref_lines, hyp_lines, case_sensitive)
def _get_ngrams_with_counter(segment, max_order):
"""Extracts all n-grams up to a given maximum order from an input segment.
Args:
segment: text segment from which n-grams will be extracted.
max_order: maximum length in tokens of the n-grams returned by this
method.
Returns:
The Counter containing all n-grams up to max_order in segment
with a count of how many times each n-gram occurred.
"""
ngram_counts = collections.Counter()
for order in range(1, max_order + 1):
for i in range(0, len(segment) - order + 1):
ngram = tuple(segment[i:i + order])
ngram_counts[ngram] += 1
return ngram_counts
def compute_bleu(reference_corpus, translation_corpus, max_order=4,
use_bp=True):
"""Computes BLEU score of translated segments against one or more references.
Args:
reference_corpus: list of references for each translation. Each
reference should be tokenized into a list of tokens.
translation_corpus: list of translations to score. Each translation
should be tokenized into a list of tokens.
max_order: Maximum n-gram order to use when computing BLEU score.
use_bp: boolean, whether to apply brevity penalty.
Returns:
BLEU score.
"""
reference_length = 0
translation_length = 0
bp = 1.0
geo_mean = 0
matches_by_order = [0] * max_order
possible_matches_by_order = [0] * max_order
precisions = []
for (references, translations) in zip(reference_corpus, translation_corpus):
reference_length += len(references)
translation_length += len(translations)
ref_ngram_counts = _get_ngrams_with_counter(references, max_order)
translation_ngram_counts = _get_ngrams_with_counter(translations, max_order)
overlap = dict((ngram,
min(count, translation_ngram_counts[ngram]))
for ngram, count in ref_ngram_counts.items())
for ngram in overlap:
matches_by_order[len(ngram) - 1] += overlap[ngram]
for ngram in translation_ngram_counts:
possible_matches_by_order[len(ngram) - 1] += translation_ngram_counts[
ngram]
precisions = [0] * max_order
smooth = 1.0
for i in range(0, max_order):
if possible_matches_by_order[i] > 0:
precisions[i] = float(matches_by_order[i]) / possible_matches_by_order[i]
if matches_by_order[i] > 0:
precisions[i] = float(matches_by_order[i]) / possible_matches_by_order[
i]
else:
smooth *= 2
precisions[i] = 1.0 / (smooth * possible_matches_by_order[i])
else:
precisions[i] = 0.0
if max(precisions) > 0:
p_log_sum = sum(math.log(p) for p in precisions if p)
geo_mean = math.exp(p_log_sum / max_order)
if use_bp:
ratio = translation_length / reference_length
bp = math.exp(1 - 1. / ratio) if ratio < 1.0 else 1.0
bleu = geo_mean * bp
return np.float32(bleu)
def bleu_on_list(ref_lines, hyp_lines, case_sensitive=False):
"""Compute BLEU for two list of strings (reference and hypothesis)."""
if len(ref_lines) != len(hyp_lines):
raise ValueError(
"Reference and translation files have different number of "
"lines (%d VS %d). If training only a few steps (100-200), the "
"translation may be empty." % (len(ref_lines), len(hyp_lines)))
if not case_sensitive:
ref_lines = [x.lower() for x in ref_lines]
hyp_lines = [x.lower() for x in hyp_lines]
ref_tokens = [bleu_tokenize(x) for x in ref_lines]
hyp_tokens = [bleu_tokenize(x) for x in hyp_lines]
return compute_bleu(ref_tokens, hyp_tokens) * 100
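As a quick, illustrative use of the module above (not part of the commit): bleu_on_list takes raw reference and hypothesis strings, tokenizes them with bleu_tokenize, and returns BLEU scaled to 0-100.

# Illustrative only; the sentences below are made-up examples.
from official.nlp.metrics import bleu

refs = ["the cat sat on the mat", "More tests!"]
hyps = ["the cat sat on the mat", "more tests!"]

# Case-insensitive by default; returns a score in [0, 100].
print(bleu.bleu_on_list(refs, hyps))
# Case-sensitive scoring penalizes the casing mismatch in the second pair.
print(bleu.bleu_on_list(refs, hyps, case_sensitive=True))
# The tokenizer splits off punctuation except digit-internal separators.
print(bleu.bleu_tokenize("Test0, 1 two, 3"))  # ['Test0', ',', '1', 'two', ',', '3']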
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test functions in compute_blue.py."""
import tempfile
import tensorflow as tf
from official.nlp.metrics import bleu
class ComputeBleuTest(tf.test.TestCase):
def _create_temp_file(self, text):
temp_file = tempfile.NamedTemporaryFile(delete=False)
with tf.io.gfile.GFile(temp_file.name, "w") as w:
w.write(text)
return temp_file.name
def test_bleu_same(self):
ref = self._create_temp_file("test 1 two 3\nmore tests!")
hyp = self._create_temp_file("test 1 two 3\nmore tests!")
uncased_score = bleu.bleu_wrapper(ref, hyp, False)
cased_score = bleu.bleu_wrapper(ref, hyp, True)
self.assertEqual(100, uncased_score)
self.assertEqual(100, cased_score)
def test_bleu_same_different_case(self):
ref = self._create_temp_file("Test 1 two 3\nmore tests!")
hyp = self._create_temp_file("test 1 two 3\nMore tests!")
uncased_score = bleu.bleu_wrapper(ref, hyp, False)
cased_score = bleu.bleu_wrapper(ref, hyp, True)
self.assertEqual(100, uncased_score)
self.assertLess(cased_score, 100)
def test_bleu_different(self):
ref = self._create_temp_file("Testing\nmore tests!")
hyp = self._create_temp_file("Dog\nCat")
uncased_score = bleu.bleu_wrapper(ref, hyp, False)
cased_score = bleu.bleu_wrapper(ref, hyp, True)
self.assertLess(uncased_score, 100)
self.assertLess(cased_score, 100)
def test_bleu_tokenize(self):
s = "Test0, 1 two, 3"
tokenized = bleu.bleu_tokenize(s)
self.assertEqual(["Test0", ",", "1", "two", ",", "3"], tokenized)
def test_bleu_list(self):
ref = ["test 1 two 3", "more tests!"]
hyp = ["test 1 two 3", "More tests!"]
uncased_score = bleu.bleu_on_list(ref, hyp, False)
cased_score = bleu.bleu_on_list(ref, hyp, True)
self.assertEqual(uncased_score, 100)
self.assertLess(cased_score, 100)
if __name__ == "__main__":
tf.test.main()
@@ -26,6 +26,7 @@ from official.nlp.modeling.layers.mobile_bert_layers import MobileBertMaskedLM
 from official.nlp.modeling.layers.mobile_bert_layers import MobileBertTransformer
 from official.nlp.modeling.layers.multi_channel_attention import *
 from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
+from official.nlp.modeling.layers.position_embedding import RelativePositionBias
 from official.nlp.modeling.layers.position_embedding import RelativePositionEmbedding
 from official.nlp.modeling.layers.relative_attention import MultiHeadRelativeAttention
 from official.nlp.modeling.layers.relative_attention import TwoStreamRelativeAttention
...
@@ -14,13 +14,15 @@
 # ==============================================================================
 """Keras-based positional embedding layer."""
 # pylint: disable=g-classes-have-attributes
 import math
+from typing import Optional
 import tensorflow as tf
 from official.modeling import tf_utils
+Initializer = tf.keras.initializers.Initializer
 @tf.keras.utils.register_keras_serializable(package="Text")
 class RelativePositionEmbedding(tf.keras.layers.Layer):
@@ -38,9 +40,9 @@ class RelativePositionEmbedding(tf.keras.layers.Layer):
 """
 def __init__(self,
-hidden_size,
-min_timescale=1.0,
-max_timescale=1.0e4,
+hidden_size: int,
+min_timescale: float = 1.0,
+max_timescale: float = 1.0e4,
 **kwargs):
 # We need to have a default dtype of float32, since the inputs (which Keras
 # usually uses to infer the dtype) will always be int32.
@@ -50,7 +52,7 @@ class RelativePositionEmbedding(tf.keras.layers.Layer):
 if "dtype" not in kwargs:
 kwargs["dtype"] = "float32"
-super(RelativePositionEmbedding, self).__init__(**kwargs)
+super().__init__(**kwargs)
 self._hidden_size = hidden_size
 self._min_timescale = min_timescale
 self._max_timescale = max_timescale
@@ -101,3 +103,135 @@ class RelativePositionEmbedding(tf.keras.layers.Layer):
 [tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
 return position_embeddings
def _relative_position_bucket(relative_position,
bidirectional=True,
num_buckets=32,
max_distance=128):
"""Translate relative position to a bucket number for relative attention.
The relative position is defined as memory_position - query_position, i.e.
the distance in tokens from the attending position to the attended-to
position.
If bidirectional=False, then positive relative positions are invalid.
We use smaller buckets for small absolute relative_position and larger
buckets for larger absolute relative_positions.
All relative positions >=max_distance map to the same bucket.
All relative positions <=-max_distance map to the same bucket.
This should allow for more graceful generalization to longer sequences
than the model has been trained on.
Args:
relative_position: an int32 Tensor
bidirectional: a boolean - whether the attention is bidirectional
num_buckets: an integer
max_distance: an integer
Returns:
a Tensor with the same shape as relative_position, containing int32
values in the range [0, num_buckets)
"""
ret = 0
n = -relative_position
if bidirectional:
num_buckets //= 2
ret += tf.cast(tf.math.less(n, 0), tf.int32) * num_buckets
n = tf.math.abs(n)
else:
n = tf.math.maximum(n, 0)
# now n is in the range [0, inf)
max_exact = num_buckets // 2
is_small = tf.math.less(n, max_exact)
val_if_large = max_exact + tf.dtypes.cast(
tf.math.log(tf.cast(n, tf.float32) / max_exact) /
math.log(max_distance / max_exact) * (num_buckets - max_exact),
tf.int32,
)
val_if_large = tf.math.minimum(val_if_large, num_buckets - 1)
ret += tf.where(is_small, n, val_if_large)
return ret
@tf.keras.utils.register_keras_serializable(package="Text")
class RelativePositionBias(tf.keras.layers.Layer):
"""Relative position embedding via per-head bias in T5 style.
Reference implementation in MeshTF:
https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L1000
This layer implements the relative position bias used in "Exploring the Limits
of Transfer Learning with a Unified Text-to-Text Transformer"
(https://arxiv.org/abs/1910.10683)
"""
def __init__(self,
num_heads: int,
relative_attention_num_buckets: int = 32,
relative_attention_max_distance: int = 128,
bidirectional: bool = True,
embeddings_initializer: Optional[Initializer] = None,
**kwargs):
super().__init__(**kwargs)
self.num_heads = num_heads
self.relative_attention_num_buckets = relative_attention_num_buckets
self.bidirectional = bidirectional
self.relative_attention_max_distance = relative_attention_max_distance
if embeddings_initializer:
self._embed_init = embeddings_initializer
else:
self._embed_init = tf.keras.initializers.TruncatedNormal(stddev=1.0)
with tf.name_scope(self.name):
self._relative_attention_bias = self.add_weight(
"rel_embedding",
shape=[self.relative_attention_num_buckets, self.num_heads],
initializer=self._embed_init,
dtype=self.dtype,
trainable=True)
def get_config(self):
config = {
"num_heads":
self.num_heads,
"relative_attention_num_buckets":
self.relative_attention_num_buckets,
"relative_attention_max_distance":
self.relative_attention_max_distance,
"bidirectional":
self.bidirectional,
"embeddings_initializer":
tf.keras.initializers.serialize(self._embed_init),
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, query: tf.Tensor, key: tf.Tensor):
"""Implements the forward pass.
Args:
query: query input tensor shape [batch, query length, hidden size].
key: key input tensor shape [batch, key length, hidden size].
Returns:
A tensor in shape of [batch, heads, query length, key length].
"""
batch_size, qlen = tf_utils.get_shape_list(query)[:2]
klen = tf_utils.get_shape_list(key)[1]
context_position = tf.range(qlen)[:, None]
memory_position = tf.range(klen)[None, :]
relative_position = memory_position - context_position
rp_bucket = _relative_position_bucket(
relative_position,
bidirectional=self.bidirectional,
num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance)
values = tf.nn.embedding_lookup(self._relative_attention_bias, rp_bucket)
values = tf.expand_dims(
tf.transpose(values, [2, 0, 1]),
axis=0) # shape (1, num_heads, qlen, klen)
values = tf.tile(values, [batch_size, 1, 1, 1])
return values
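A small usage sketch for the new RelativePositionBias layer (illustrative, echoing the unit test further down): the bias depends only on the query and key lengths, so zero-valued tensors are enough to demonstrate the output shape.

# Illustrative usage; shapes are arbitrary example values.
import tensorflow as tf

from official.nlp.modeling.layers import position_embedding

layer = position_embedding.RelativePositionBias(
    num_heads=8, relative_attention_num_buckets=32, bidirectional=True)

query = tf.zeros((2, 10, 64))  # [batch, query_length, hidden_size]
key = tf.zeros((2, 12, 64))    # [batch, key_length, hidden_size]

bias = layer(query, key)
print(bias.shape)  # (2, 8, 10, 12) == [batch, num_heads, query_length, key_length]

# The result is typically added to the raw attention logits (same shape)
# before the softmax, as in T5-style relative attention.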
@@ -14,6 +14,8 @@
 # ==============================================================================
 """Tests for Keras-based positional embedding layer."""
+from absl.testing import parameterized
+import numpy as np
 import tensorflow as tf
 from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import
@@ -55,5 +57,32 @@ class RelativePositionEmbeddingLayerTest(keras_parameterized.TestCase):
 self.assertAllEqual(output_tensor, expected_output_tensor)
@keras_parameterized.run_all_keras_modes
class RelativePositionBiasTest(keras_parameterized.TestCase):
@parameterized.named_parameters(("bidirectional", True),
("unidirectional", False))
def test_relative_position_bias(self, bidirectional):
query = tf.zeros((4, 4, 2))
key = tf.zeros((4, 2, 2))
l = position_embedding.RelativePositionBias(
num_heads=3,
bidirectional=bidirectional,
name="foo")
self.assertEqual(l(query, key).shape, (4, 3, 4, 2))
self.assertLen(l.trainable_variables, 1)
self.assertEqual(l.trainable_variables[0].name, "foo/rel_embedding:0")
def test_relative_position_bucket(self):
context_position = tf.range(3)[:, None]
memory_position = tf.range(2)[None, :]
relative_position = memory_position - context_position
outputs = position_embedding._relative_position_bucket(relative_position)
self.assertAllEqual(outputs.numpy(), np.array([[0, 17], [1, 0], [2, 1]]))
outputs = position_embedding._relative_position_bucket(
relative_position, bidirectional=False)
self.assertAllEqual(outputs.numpy(), np.array([[0, 0], [1, 0], [2, 1]]))
if __name__ == "__main__": if __name__ == "__main__":
tf.test.main() tf.test.main()
@@ -133,7 +133,16 @@ class BertTokenizer(tf.keras.layers.Layer):
 _check_if_tf_text_installed()
 self.tokenize_with_offsets = tokenize_with_offsets
-self._vocab_table = self._create_vocab_table(vocab_file)
+# TODO(b/177326279): Stop storing the vocab table initializer as an
+# attribute when https://github.com/tensorflow/tensorflow/issues/46456
+# has been fixed in the TensorFlow versions of the TF Hub users that load
+# a SavedModel created from this layer. Due to that issue, loading such a
+# SavedModel forgets to add .vocab_table._initializer as a trackable
+# dependency of .vocab_table, so that saving it again to a second SavedModel
+# (e.g., the final model built using TF Hub) does not properly track
+# the ._vocab_table._initializer._filename as an Asset.
+self._vocab_table, self._vocab_initializer_donotuse = (
+    self._create_vocab_table_and_initializer(vocab_file))
 self._special_tokens_dict = self._create_special_tokens_dict(
 self._vocab_table, vocab_file)
 super().__init__(**kwargs)
@@ -144,12 +153,13 @@ class BertTokenizer(tf.keras.layers.Layer):
 def vocab_size(self):
 return self._vocab_table.size()
-def _create_vocab_table(self, vocab_file):
+def _create_vocab_table_and_initializer(self, vocab_file):
 vocab_initializer = tf.lookup.TextFileInitializer(
 vocab_file,
 key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
 value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
-return tf.lookup.StaticHashTable(vocab_initializer, default_value=-1)
+vocab_table = tf.lookup.StaticHashTable(vocab_initializer, default_value=-1)
+return vocab_table, vocab_initializer
 def call(self, inputs: tf.Tensor):
 """Calls text.BertTokenizer on inputs.
@@ -211,6 +221,7 @@ class BertTokenizer(tf.keras.layers.Layer):
 * end_of_segment_id: looked up from "[SEP]"
 * padding_id: looked up form "[PAD]"
 * mask_id: looked up from "[MASK]"
+* vocab_size: one past the largest token id used
 """
 return self._special_tokens_dict
@@ -223,6 +234,7 @@ class BertTokenizer(tf.keras.layers.Layer):
 if tf.executing_eagerly():
 special_token_ids = vocab_table.lookup(
 tf.constant(list(special_tokens.values()), tf.string))
+vocab_size = vocab_table.size()
 else:
 # A blast from the past: non-eager init context while building Model.
 # This can happen with Estimator or tf.compat.v1.disable_v2_behavior().
@@ -230,16 +242,21 @@ class BertTokenizer(tf.keras.layers.Layer):
 "Non-eager init context; computing "
 "BertTokenizer's special_tokens_dict in tf.compat.v1.Session")
 with tf.Graph().as_default():
-local_vocab_table = self._create_vocab_table(vocab_file)
+local_vocab_table, _ = self._create_vocab_table_and_initializer(
+    vocab_file)
 special_token_ids_tensor = local_vocab_table.lookup(
 tf.constant(list(special_tokens.values()), tf.string))
+vocab_size_tensor = local_vocab_table.size()
 init_ops = [tf.compat.v1.initialize_all_tables()]
 with tf.compat.v1.Session() as sess:
 sess.run(init_ops)
-special_token_ids = sess.run(special_token_ids_tensor)
-result = dict()
+special_token_ids, vocab_size = sess.run(
+    [special_token_ids_tensor, vocab_size_tensor])
+result = dict(
+    vocab_size=int(vocab_size)  # Numpy to Python.
+)
 for k, v in zip(special_tokens, special_token_ids):
-v = int(v)  # Numpy to Python.
+v = int(v)
 if v >= 0:
 result[k] = v
 else:
@@ -414,6 +431,7 @@ class SentencepieceTokenizer(tf.keras.layers.Layer):
 * end_of_segment_id: looked up from "[SEP]"
 * padding_id: looked up from "<pad>"
 * mask_id: looked up from "[MASK]"
+* vocab_size: one past the largest token id used
 """
 return self._special_tokens_dict
@@ -428,6 +446,7 @@ class SentencepieceTokenizer(tf.keras.layers.Layer):
 special_token_ids = self._tokenizer.string_to_id(
 tf.constant(list(special_tokens.values()), tf.string))
 inverse_tokens = self._tokenizer.id_to_string(special_token_ids)
+vocab_size = self._tokenizer.vocab_size()
 else:
 # A blast from the past: non-eager init context while building Model.
 # This can happen with Estimator or tf.compat.v1.disable_v2_behavior().
@@ -440,15 +459,19 @@ class SentencepieceTokenizer(tf.keras.layers.Layer):
 tf.constant(list(special_tokens.values()), tf.string))
 inverse_tokens_tensor = local_tokenizer.id_to_string(
 special_token_ids_tensor)
+vocab_size_tensor = local_tokenizer.vocab_size()
 with tf.compat.v1.Session() as sess:
-special_token_ids, inverse_tokens = sess.run(
-    [special_token_ids_tensor, inverse_tokens_tensor])
-result = dict()
+special_token_ids, inverse_tokens, vocab_size = sess.run(
+    [special_token_ids_tensor, inverse_tokens_tensor,
+     vocab_size_tensor])
+result = dict(
+    vocab_size=int(vocab_size)  # Numpy to Python.
+)
 for name, token_id, inverse_token in zip(special_tokens,
 special_token_ids,
 inverse_tokens):
 if special_tokens[name] == inverse_token:
-result[name] = int(token_id)  # Numpy to Python.
+result[name] = int(token_id)
 else:
 logging.warning(
 "Could not find %s as token \"%s\" in sentencepiece model, "
...
@@ -130,7 +130,8 @@ class BertTokenizerTest(tf.test.TestCase):
 dict(padding_id=1,
 start_of_sequence_id=3,
 end_of_segment_id=4,
-mask_id=5))
+mask_id=5,
+vocab_size=7))
 def test_special_tokens_partial(self):
 vocab_file = self._make_vocab_file(
@@ -140,7 +141,8 @@ class BertTokenizerTest(tf.test.TestCase):
 self.assertDictEqual(bert_tokenize.get_special_tokens_dict(),
 dict(padding_id=0,
 start_of_sequence_id=1,
-end_of_segment_id=2))  # No mask_id,
+end_of_segment_id=2,
+vocab_size=3))  # No mask_id,
 def test_special_tokens_in_estimator(self):
 """Tests getting special tokens without an Eager init context."""
@@ -252,7 +254,8 @@ class SentencepieceTokenizerTest(tf.test.TestCase):
 dict(padding_id=0,
 start_of_sequence_id=2,
 end_of_segment_id=3,
-mask_id=4))
+mask_id=4,
+vocab_size=16))
 def test_special_tokens_in_estimator(self):
 """Tests getting special tokens without an Eager init context."""
...
@@ -160,8 +160,8 @@ class Seq2SeqTransformer(tf.keras.Model):
 embedded_inputs = self.embedding_lookup(sources)
 embedding_mask = tf.cast(
 tf.not_equal(sources, 0), self.embedding_lookup.embeddings.dtype)
-embedded_inputs *= tf.expand_dims(embedding_mask, -1)
 embedded_inputs = tf.cast(embedded_inputs, self._dtype)
+embedded_inputs *= tf.expand_dims(embedding_mask, -1)
 # Attention_mask generation.
 input_shape = tf_utils.get_shape_list(sources, expected_rank=2)
 attention_mask = tf.cast(
@@ -243,8 +243,8 @@ class Seq2SeqTransformer(tf.keras.Model):
 decoder_inputs = self.embedding_lookup(targets)
 embedding_mask = tf.cast(
 tf.not_equal(targets, 0), self.embedding_lookup.embeddings.dtype)
-decoder_inputs *= tf.expand_dims(embedding_mask, -1)
 decoder_inputs = tf.cast(decoder_inputs, self._dtype)
+decoder_inputs *= tf.expand_dims(embedding_mask, -1)
 # Shift targets to the right, and remove the last element
 decoder_inputs = tf.pad(decoder_inputs, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
 length = tf.shape(decoder_inputs)[1]
...
@@ -82,7 +82,7 @@ Next, we can run the following data preprocess script which may take a few hours
 ```shell
 # Recall that we use DATA_FOLDER=/path/to/downloaded_dataset.
-$ python3 raw_data_preprocess.py \
+$ python3 raw_data_process.py \
 -crawled_articles=/tmp/nhnet \
 -vocab=/path/to/bert_checkpoint/vocab.txt \
 -do_lower_case=True \
...
@@ -36,9 +36,10 @@ class BigBirdEncoder(tf.keras.Model):
 num_layers: The number of transformer layers.
 num_attention_heads: The number of attention heads for each transformer. The
 hidden size must be divisible by the number of attention heads.
-max_sequence_length: The maximum sequence length that this encoder can
-consume. If None, max_sequence_length uses the value from sequence length.
-This determines the variable shape for positional embeddings.
+max_position_embeddings: The maximum length of position embeddings that this
+encoder can consume. If None, max_position_embeddings uses the value from
+sequence length. This determines the variable shape for positional
+embeddings.
 type_vocab_size: The number of types that the 'type_ids' input can take.
 intermediate_size: The intermediate size for the transformer layers.
 activation: The activation to use for the transformer layers.
@@ -58,7 +59,7 @@ class BigBirdEncoder(tf.keras.Model):
 hidden_size=768,
 num_layers=12,
 num_attention_heads=12,
-max_sequence_length=attention.MAX_SEQ_LEN,
+max_position_embeddings=attention.MAX_SEQ_LEN,
 type_vocab_size=16,
 intermediate_size=3072,
 block_size=64,
@@ -78,7 +79,7 @@ class BigBirdEncoder(tf.keras.Model):
 'hidden_size': hidden_size,
 'num_layers': num_layers,
 'num_attention_heads': num_attention_heads,
-'max_sequence_length': max_sequence_length,
+'max_position_embeddings': max_position_embeddings,
 'type_vocab_size': type_vocab_size,
 'intermediate_size': intermediate_size,
 'block_size': block_size,
@@ -109,7 +110,7 @@ class BigBirdEncoder(tf.keras.Model):
 # Always uses dynamic slicing for simplicity.
 self._position_embedding_layer = keras_nlp.layers.PositionEmbedding(
 initializer=initializer,
-max_length=max_sequence_length,
+max_length=max_position_embeddings,
 name='position_embedding')
 position_embeddings = self._position_embedding_layer(word_embeddings)
 self._type_embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
@@ -159,7 +160,7 @@ class BigBirdEncoder(tf.keras.Model):
 from_block_size=block_size,
 to_block_size=block_size,
 num_rand_blocks=num_rand_blocks,
-max_rand_mask_length=max_sequence_length,
+max_rand_mask_length=max_position_embeddings,
 seed=i),
 dropout_rate=dropout_rate,
 attention_dropout_rate=dropout_rate,
...
@@ -27,7 +27,7 @@ class BigBirdEncoderTest(tf.test.TestCase):
 batch_size = 2
 vocab_size = 1024
 network = encoder.BigBirdEncoder(
-num_layers=1, vocab_size=1024, max_sequence_length=4096)
+num_layers=1, vocab_size=1024, max_position_embeddings=4096)
 word_id_data = np.random.randint(
 vocab_size, size=(batch_size, sequence_length))
 mask_data = np.random.randint(2, size=(batch_size, sequence_length))
@@ -41,7 +41,7 @@ class BigBirdEncoderTest(tf.test.TestCase):
 batch_size = 2
 vocab_size = 1024
 network = encoder.BigBirdEncoder(
-num_layers=1, vocab_size=1024, max_sequence_length=4096)
+num_layers=1, vocab_size=1024, max_position_embeddings=4096)
 word_id_data = np.random.randint(
 vocab_size, size=(batch_size, sequence_length))
 mask_data = np.random.randint(2, size=(batch_size, sequence_length))
...
@@ -146,7 +146,7 @@ def read_model_config(encoder,
 return encoder_config
-@gin.configurable(blacklist=[
+@gin.configurable(denylist=[
 'model',
 'strategy',
 'train_dataset',
...
@@ -28,8 +28,8 @@ from official.core import config_definitions as cfg
 from official.core import task_factory
 from official.modeling.hyperparams import base_config
 from official.nlp.data import data_loader_factory
+from official.nlp.metrics import bleu
 from official.nlp.modeling import models
-from official.nlp.transformer import compute_bleu
 def _pad_tensors_to_same_length(x, y):
@@ -364,6 +364,6 @@ class TranslationTask(base_task.Task):
 src, translation, self._references[u_id])
 sacrebleu_score = sacrebleu.corpus_bleu(
 translations, [self._references]).score
-bleu_score = compute_bleu.bleu_on_list(self._references, translations)
+bleu_score = bleu.bleu_on_list(self._references, translations)
 return {"sacrebleu_score": sacrebleu_score,
 "bleu_score": bleu_score}
@@ -64,6 +64,8 @@ def main(_):
 params=params,
 model_dir=model_dir)
+train_utils.save_gin_config(FLAGS.mode, model_dir)
 if __name__ == '__main__':
 tfm_flags.define_flags()
 app.run(main)