Commit 9485aa1d authored by qianyj's avatar qianyj
Browse files

Update code to v2.8.0

parents 89cfa348 f5fc733a
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for calculating loss, accuracy, and other model metrics.
Metrics:
- Padded loss, accuracy, and negative log perplexity. Source:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/metrics.py
- BLEU approximation. Source:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/bleu_hook.py
- ROUGE score. Source:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/rouge.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import math
import numpy as np
import six
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow.compat.v1 as tf
def _pad_tensors_to_same_length(x, y):
"""Pad x and y so that the results have the same length (second dimension)."""
with tf.name_scope("pad_to_same_length"):
x_length = tf.shape(x)[1]
y_length = tf.shape(y)[1]
max_length = tf.maximum(x_length, y_length)
x = tf.pad(x, [[0, 0], [0, max_length - x_length], [0, 0]])
y = tf.pad(y, [[0, 0], [0, max_length - y_length]])
return x, y
def padded_cross_entropy_loss(logits, labels, smoothing, vocab_size):
"""Calculate cross entropy loss while ignoring padding.
Args:
logits: Tensor of size [batch_size, length_logits, vocab_size]
labels: Tensor of size [batch_size, length_labels]
smoothing: Label smoothing constant, used to determine the on and off values
vocab_size: int size of the vocabulary
Returns:
Returns the cross entropy loss and weight tensors: float32 tensors with
shape [batch_size, max(length_logits, length_labels)]
"""
with tf.name_scope("loss", values=[logits, labels]):
logits, labels = _pad_tensors_to_same_length(logits, labels)
# Calculate smoothing cross entropy
with tf.name_scope("smoothing_cross_entropy", values=[logits, labels]):
confidence = 1.0 - smoothing
low_confidence = (1.0 - confidence) / tf.cast(vocab_size - 1, tf.float32)
soft_targets = tf.one_hot(
tf.cast(labels, tf.int32),
depth=vocab_size,
on_value=confidence,
off_value=low_confidence)
xentropy = tf.nn.softmax_cross_entropy_with_logits_v2(
logits=logits, labels=soft_targets)
# Calculate the best (lowest) possible value of cross entropy, and
# subtract from the cross entropy loss.
normalizing_constant = -(
confidence * tf.log(confidence) + tf.cast(vocab_size - 1, tf.float32)
* low_confidence * tf.log(low_confidence + 1e-20))
xentropy -= normalizing_constant
weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
return xentropy * weights, weights
def _convert_to_eval_metric(metric_fn):
"""Wrap a metric fn that returns scores and weights as an eval metric fn.
The input metric_fn returns values for the current batch. The wrapper
aggregates the return values collected over all of the batches evaluated.
Args:
metric_fn: function that returns scores and weights for the current batch's
logits and predicted labels.
Returns:
function that aggregates the scores and weights from metric_fn.
"""
def problem_metric_fn(*args):
"""Returns an aggregation of the metric_fn's returned values."""
(scores, weights) = metric_fn(*args)
# The tf.metrics.mean function assures correct aggregation.
return tf.metrics.mean(scores, weights)
return problem_metric_fn
def get_eval_metrics(logits, labels, params):
"""Return dictionary of model evaluation metrics."""
metrics = {
"accuracy": _convert_to_eval_metric(padded_accuracy)(logits, labels),
"accuracy_top5": _convert_to_eval_metric(padded_accuracy_top5)(
logits, labels),
"accuracy_per_sequence": _convert_to_eval_metric(
padded_sequence_accuracy)(logits, labels),
"neg_log_perplexity": _convert_to_eval_metric(padded_neg_log_perplexity)(
logits, labels, params["vocab_size"]),
}
if not params["use_tpu"]:
# TPU does not support tf.py_func
metrics.update({
"approx_bleu_score": _convert_to_eval_metric(
bleu_score)(logits, labels),
"rouge_2_fscore": _convert_to_eval_metric(
rouge_2_fscore)(logits, labels),
"rouge_L_fscore": _convert_to_eval_metric(
rouge_l_fscore)(logits, labels),
})
# Prefix each of the metric names with "metrics/". This allows the metric
# graphs to display under the "metrics" category in TensorBoard.
metrics = {"metrics/%s" % k: v for k, v in six.iteritems(metrics)}
return metrics
def padded_accuracy(logits, labels):
"""Percentage of times that predictions matches labels on non-0s."""
with tf.variable_scope("padded_accuracy", values=[logits, labels]):
logits, labels = _pad_tensors_to_same_length(logits, labels)
weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
outputs = tf.cast(tf.argmax(logits, axis=-1), tf.int32)
padded_labels = tf.cast(labels, tf.int32)
return tf.cast(tf.equal(outputs, padded_labels), tf.float32), weights
def padded_accuracy_topk(logits, labels, k):
"""Percentage of times that top-k predictions matches labels on non-0s."""
with tf.variable_scope("padded_accuracy_topk", values=[logits, labels]):
logits, labels = _pad_tensors_to_same_length(logits, labels)
weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
effective_k = tf.minimum(k, tf.shape(logits)[-1])
_, outputs = tf.nn.top_k(logits, k=effective_k)
outputs = tf.cast(outputs, tf.int32)
padded_labels = tf.cast(labels, tf.int32)
padded_labels = tf.expand_dims(padded_labels, axis=-1)
padded_labels += tf.zeros_like(outputs) # Pad to same shape.
same = tf.cast(tf.equal(outputs, padded_labels), tf.float32)
same_topk = tf.reduce_sum(same, axis=-1)
return same_topk, weights
def padded_accuracy_top5(logits, labels):
return padded_accuracy_topk(logits, labels, 5)
def padded_sequence_accuracy(logits, labels):
"""Percentage of times that predictions matches labels everywhere (non-0)."""
with tf.variable_scope("padded_sequence_accuracy", values=[logits, labels]):
logits, labels = _pad_tensors_to_same_length(logits, labels)
weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
outputs = tf.cast(tf.argmax(logits, axis=-1), tf.int32)
padded_labels = tf.cast(labels, tf.int32)
not_correct = (tf.cast(tf.not_equal(outputs, padded_labels), tf.float32) *
weights)
axis = list(range(1, len(outputs.get_shape())))
correct_seq = 1.0 - tf.minimum(1.0, tf.reduce_sum(not_correct, axis=axis))
return correct_seq, tf.constant(1.0)
def padded_neg_log_perplexity(logits, labels, vocab_size):
"""Average log-perplexity excluding padding 0s. No smoothing."""
num, den = padded_cross_entropy_loss(logits, labels, 0, vocab_size)
return -num, den
def bleu_score(logits, labels):
"""Approximate BLEU score computation between labels and predictions.
An approximate BLEU scoring method since we do not glue word pieces or
decode the ids and tokenize the output. By default, we use ngram order of 4
and use brevity penalty. Also, this does not have beam search.
Args:
logits: Tensor of size [batch_size, length_logits, vocab_size]
labels: Tensor of size [batch-size, length_labels]
Returns:
bleu: int, approx bleu score
"""
predictions = tf.cast(tf.argmax(logits, axis=-1), tf.int32)
# TODO: Look into removing use of py_func # pylint: disable=g-bad-todo
bleu = tf.py_func(compute_bleu, (labels, predictions), tf.float32)
return bleu, tf.constant(1.0)
def _get_ngrams_with_counter(segment, max_order):
"""Extracts all n-grams up to a given maximum order from an input segment.
Args:
segment: text segment from which n-grams will be extracted.
max_order: maximum length in tokens of the n-grams returned by this
methods.
Returns:
The Counter containing all n-grams upto max_order in segment
with a count of how many times each n-gram occurred.
"""
ngram_counts = collections.Counter()
for order in xrange(1, max_order + 1):
for i in xrange(0, len(segment) - order + 1):
ngram = tuple(segment[i:i + order])
ngram_counts[ngram] += 1
return ngram_counts
def compute_bleu(reference_corpus, translation_corpus, max_order=4,
use_bp=True):
"""Computes BLEU score of translated segments against one or more references.
Args:
reference_corpus: list of references for each translation. Each
reference should be tokenized into a list of tokens.
translation_corpus: list of translations to score. Each translation
should be tokenized into a list of tokens.
max_order: Maximum n-gram order to use when computing BLEU score.
use_bp: boolean, whether to apply brevity penalty.
Returns:
BLEU score.
"""
reference_length = 0
translation_length = 0
bp = 1.0
geo_mean = 0
matches_by_order = [0] * max_order
possible_matches_by_order = [0] * max_order
precisions = []
for (references, translations) in zip(reference_corpus, translation_corpus):
reference_length += len(references)
translation_length += len(translations)
ref_ngram_counts = _get_ngrams_with_counter(references, max_order)
translation_ngram_counts = _get_ngrams_with_counter(translations, max_order)
overlap = dict((ngram,
min(count, translation_ngram_counts[ngram]))
for ngram, count in ref_ngram_counts.items())
for ngram in overlap:
matches_by_order[len(ngram) - 1] += overlap[ngram]
for ngram in translation_ngram_counts:
possible_matches_by_order[len(ngram) - 1] += translation_ngram_counts[
ngram]
precisions = [0] * max_order
smooth = 1.0
for i in xrange(0, max_order):
if possible_matches_by_order[i] > 0:
precisions[i] = float(matches_by_order[i]) / possible_matches_by_order[i]
if matches_by_order[i] > 0:
precisions[i] = float(matches_by_order[i]) / possible_matches_by_order[
i]
else:
smooth *= 2
precisions[i] = 1.0 / (smooth * possible_matches_by_order[i])
else:
precisions[i] = 0.0
if max(precisions) > 0:
p_log_sum = sum(math.log(p) for p in precisions if p)
geo_mean = math.exp(p_log_sum / max_order)
if use_bp:
ratio = translation_length / reference_length
bp = math.exp(1 - 1. / ratio) if ratio < 1.0 else 1.0
bleu = geo_mean * bp
return np.float32(bleu)
def rouge_2_fscore(logits, labels):
"""ROUGE-2 F1 score computation between labels and predictions.
This is an approximate ROUGE scoring method since we do not glue word pieces
or decode the ids and tokenize the output.
Args:
logits: tensor, model predictions
labels: tensor, gold output.
Returns:
rouge2_fscore: approx rouge-2 f1 score.
"""
predictions = tf.cast(tf.argmax(logits, axis=-1), tf.int32)
# TODO: Look into removing use of py_func # pylint: disable=g-bad-todo
rouge_2_f_score = tf.py_func(rouge_n, (predictions, labels), tf.float32)
return rouge_2_f_score, tf.constant(1.0)
def _get_ngrams(n, text):
"""Calculates n-grams.
Args:
n: which n-grams to calculate
text: An array of tokens
Returns:
A set of n-grams
"""
ngram_set = set()
text_length = len(text)
max_index_ngram_start = text_length - n
for i in range(max_index_ngram_start + 1):
ngram_set.add(tuple(text[i:i + n]))
return ngram_set
def rouge_n(eval_sentences, ref_sentences, n=2):
"""Computes ROUGE-N f1 score of two text collections of sentences.
Source: https://www.microsoft.com/en-us/research/publication/
rouge-a-package-for-automatic-evaluation-of-summaries/
Args:
eval_sentences: Predicted sentences.
ref_sentences: Sentences from the reference set
n: Size of ngram. Defaults to 2.
Returns:
f1 score for ROUGE-N
"""
f1_scores = []
for eval_sentence, ref_sentence in zip(eval_sentences, ref_sentences):
eval_ngrams = _get_ngrams(n, eval_sentence)
ref_ngrams = _get_ngrams(n, ref_sentence)
ref_count = len(ref_ngrams)
eval_count = len(eval_ngrams)
# Count the overlapping ngrams between evaluated and reference
overlapping_ngrams = eval_ngrams.intersection(ref_ngrams)
overlapping_count = len(overlapping_ngrams)
# Handle edge case. This isn't mathematically correct, but it's good enough
if eval_count == 0:
precision = 0.0
else:
precision = float(overlapping_count) / eval_count
if ref_count == 0:
recall = 0.0
else:
recall = float(overlapping_count) / ref_count
f1_scores.append(2.0 * ((precision * recall) / (precision + recall + 1e-8)))
# return overlapping_count / reference_count
return np.mean(f1_scores, dtype=np.float32)
def rouge_l_fscore(predictions, labels):
"""ROUGE scores computation between labels and predictions.
This is an approximate ROUGE scoring method since we do not glue word pieces
or decode the ids and tokenize the output.
Args:
predictions: tensor, model predictions
labels: tensor, gold output.
Returns:
rouge_l_fscore: approx rouge-l f1 score.
"""
outputs = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)
rouge_l_f_score = tf.py_func(rouge_l_sentence_level, (outputs, labels),
tf.float32)
return rouge_l_f_score, tf.constant(1.0)
def rouge_l_sentence_level(eval_sentences, ref_sentences):
"""Computes ROUGE-L (sentence level) of two collections of sentences.
Source: https://www.microsoft.com/en-us/research/publication/
rouge-a-package-for-automatic-evaluation-of-summaries/
Calculated according to:
R_lcs = LCS(X,Y)/m
P_lcs = LCS(X,Y)/n
F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs)
where:
X = reference summary
Y = Candidate summary
m = length of reference summary
n = length of candidate summary
Args:
eval_sentences: The sentences that have been picked by the summarizer
ref_sentences: The sentences from the reference set
Returns:
A float: F_lcs
"""
f1_scores = []
for eval_sentence, ref_sentence in zip(eval_sentences, ref_sentences):
m = float(len(ref_sentence))
n = float(len(eval_sentence))
lcs = _len_lcs(eval_sentence, ref_sentence)
f1_scores.append(_f_lcs(lcs, m, n))
return np.mean(f1_scores, dtype=np.float32)
def _len_lcs(x, y):
"""Returns the length of the Longest Common Subsequence between two seqs.
Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence
Args:
x: sequence of words
y: sequence of words
Returns
integer: Length of LCS between x and y
"""
table = _lcs(x, y)
n, m = len(x), len(y)
return table[n, m]
def _lcs(x, y):
"""Computes the length of the LCS between two seqs.
The implementation below uses a DP programming algorithm and runs
in O(nm) time where n = len(x) and m = len(y).
Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence
Args:
x: collection of words
y: collection of words
Returns:
Table of dictionary of coord and len lcs
"""
n, m = len(x), len(y)
table = dict()
for i in range(n + 1):
for j in range(m + 1):
if i == 0 or j == 0:
table[i, j] = 0
elif x[i - 1] == y[j - 1]:
table[i, j] = table[i - 1, j - 1] + 1
else:
table[i, j] = max(table[i - 1, j], table[i, j - 1])
return table
def _f_lcs(llcs, m, n):
"""Computes the LCS-based F-measure score.
Source: http://research.microsoft.com/en-us/um/people/cyl/download/papers/
rouge-working-note-v1.3.1.pdf
Args:
llcs: Length of LCS
m: number of words in reference summary
n: number of words in candidate summary
Returns:
Float. LCS-based F-measure score
"""
r_lcs = llcs / m
p_lcs = llcs / n
beta = p_lcs / (r_lcs + 1e-12)
num = (1 + (beta ** 2)) * r_lcs * p_lcs
denom = r_lcs + ((beta ** 2) * p_lcs)
f_lcs = num / (denom + 1e-12)
return f_lcs
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines Subtokenizer class to encode and decode strings."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
import sys
import unicodedata
from absl import logging
import numpy as np
import six
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
# pylint: disable=g-complex-comprehension
PAD = "<pad>"
PAD_ID = 0
EOS = "<EOS>"
EOS_ID = 1
RESERVED_TOKENS = [PAD, EOS]
# Set of characters that will be used in the function _escape_token() (see func
# docstring for more details).
# This set is added to the alphabet list to ensure that all escaped tokens can
# be encoded.
_ESCAPE_CHARS = set(u"\\_u;0123456789")
# Regex for the function _unescape_token(), the inverse of _escape_token().
# This is used to find "\u", "\\", and "\###;" substrings in the token.
_UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);")
_UNDEFINED_UNICODE = u"\u3013"
def alphanumeric_char_set():
return set(
six.unichr(i)
for i in xrange(sys.maxunicode)
if (unicodedata.category(six.unichr(i)).startswith("L") or
unicodedata.category(six.unichr(i)).startswith("N")))
# Set contains all letter and number characters.
_ALPHANUMERIC_CHAR_SET = alphanumeric_char_set()
# min_count is the minimum number of times a subtoken must appear in the data
# before before it is added to the vocabulary. The value is found using binary
# search to obtain the target vocabulary size.
_MIN_MIN_COUNT = 1 # min value to use when binary searching for min_count
_MAX_MIN_COUNT = 1000 # max value to use when binary searching for min_count
class Subtokenizer(object):
"""Encodes and decodes strings to/from integer IDs."""
def __init__(self, vocab_file, reserved_tokens=None, master_char_set=None):
"""Initializes class, creating a vocab file if data_files is provided."""
logging.info("Initializing Subtokenizer from file %s.", vocab_file)
if master_char_set is None:
master_char_set = _ALPHANUMERIC_CHAR_SET
if reserved_tokens is None:
reserved_tokens = RESERVED_TOKENS
self.subtoken_list = _load_vocab_file(vocab_file, reserved_tokens)
self.alphabet = _generate_alphabet_dict(self.subtoken_list)
self.subtoken_to_id_dict = _list_to_index_dict(self.subtoken_list)
self.max_subtoken_length = 0
for subtoken in self.subtoken_list:
self.max_subtoken_length = max(self.max_subtoken_length, len(subtoken))
# Create cache to speed up subtokenization
self._cache_size = 2**20
self._cache = [(None, None)] * self._cache_size
self._master_char_set = master_char_set
@staticmethod
def init_from_files(vocab_file,
files,
target_vocab_size,
threshold,
min_count=None,
file_byte_limit=1e6,
reserved_tokens=None,
correct_strip=True,
master_char_set=None):
"""Create subtoken vocabulary based on files, and save vocab to file.
Args:
vocab_file: String name of vocab file to store subtoken vocabulary.
files: List of file paths that will be used to generate vocabulary.
target_vocab_size: target vocabulary size to generate.
threshold: int threshold of vocabulary size to accept.
min_count: int minimum count to use for generating the vocabulary. The min
count is the minimum number of times a subtoken should appear in the
files before it is added to the vocabulary. If set to none, this value
is found using binary search.
file_byte_limit: (Default 1e6) Maximum number of bytes of sample text that
will be drawn from the files.
reserved_tokens: List of string tokens that are guaranteed to be at the
beginning of the subtoken vocabulary list.
correct_strip: Whether to convert text to unicode before strip.
master_char_set: the char set.
Returns:
Subtokenizer object
"""
if master_char_set is None:
master_char_set = _ALPHANUMERIC_CHAR_SET
if reserved_tokens is None:
reserved_tokens = RESERVED_TOKENS
if tf.io.gfile.exists(vocab_file):
logging.info("Vocab file already exists (%s)", vocab_file)
else:
logging.info("Begin steps to create subtoken vocabulary...")
token_counts = _count_tokens(files, file_byte_limit, correct_strip,
master_char_set)
alphabet = _generate_alphabet_dict(token_counts)
subtoken_list = _generate_subtokens_with_target_vocab_size(
token_counts, alphabet, target_vocab_size, threshold, min_count,
reserved_tokens)
logging.info("Generated vocabulary with %d subtokens.",
len(subtoken_list))
_save_vocab_file(vocab_file, subtoken_list)
return Subtokenizer(vocab_file, master_char_set=master_char_set)
def encode(self, raw_string, add_eos=False):
"""Encodes a string into a list of int subtoken ids."""
ret = []
tokens = _split_string_to_tokens(
native_to_unicode(raw_string), self._master_char_set)
for token in tokens:
ret.extend(self._token_to_subtoken_ids(token))
if add_eos:
assert EOS in self.subtoken_list, \
"Can't append 'EOS' because it is not in list of known subtokens."
ret.append(EOS_ID)
return ret
def _token_to_subtoken_ids(self, token):
"""Encode a single token into a list of subtoken ids."""
cache_location = hash(token) % self._cache_size
cache_key, cache_value = self._cache[cache_location]
if cache_key == token:
return cache_value
ret = _split_token_to_subtokens(
_escape_token(token, self.alphabet), self.subtoken_to_id_dict,
self.max_subtoken_length)
ret = [self.subtoken_to_id_dict[subtoken_id] for subtoken_id in ret]
self._cache[cache_location] = (token, ret)
return ret
def decode(self, subtokens):
"""Converts list of int subtokens ids into a string."""
if isinstance(subtokens, np.ndarray):
# Note that list(subtokens) converts subtokens to a python list, but the
# items remain as np.int32. This converts both the array and its items.
subtokens = subtokens.tolist()
if not subtokens:
return ""
assert isinstance(subtokens, list) and isinstance(subtokens[0], int), (
"Subtokens argument passed into decode() must be a list of integers.")
return _unicode_to_native(
_join_tokens_to_string(
self._subtoken_ids_to_tokens(subtokens), self._master_char_set))
def _subtoken_ids_to_tokens(self, subtokens):
"""Convert list of int subtoken ids to a list of string tokens."""
escaped_tokens = "".join([
self.subtoken_list[s] for s in subtokens if s < len(self.subtoken_list)
])
escaped_tokens = escaped_tokens.split("_")
# All tokens in the vocabulary list have been escaped (see _escape_token())
# so each token must be unescaped when decoding.
ret = []
for token in escaped_tokens:
if token:
ret.append(_unescape_token(token))
return ret
def _save_vocab_file(vocab_file, subtoken_list):
"""Save subtokens to file."""
with tf.io.gfile.GFile(vocab_file, mode="w") as f:
for subtoken in subtoken_list:
f.write("'%s'\n" % _unicode_to_native(subtoken))
def _load_vocab_file(vocab_file, reserved_tokens=None):
"""Load vocabulary while ensuring reserved tokens are at the top."""
if reserved_tokens is None:
reserved_tokens = RESERVED_TOKENS
subtoken_list = []
with tf.io.gfile.GFile(vocab_file, mode="r") as f:
for line in f:
subtoken = native_to_unicode(line.strip())
subtoken = subtoken[1:-1] # Remove surrounding single-quotes
if subtoken in reserved_tokens:
continue
subtoken_list.append(native_to_unicode(subtoken))
return reserved_tokens + subtoken_list
def native_to_unicode(s):
"""Convert string to unicode (required in Python 2)."""
try: # Python 2
return s if isinstance(s, unicode) else s.decode("utf-8")
except NameError: # Python 3
return s
def _unicode_to_native(s):
"""Convert string from unicode to native format (required in Python 2)."""
try: # Python 2
return s.encode("utf-8") if isinstance(s, unicode) else s
except NameError: # Python 3
return s
def _split_string_to_tokens(text, master_char_set):
"""Splits text to a list of string tokens."""
if not text:
return []
ret = []
token_start = 0
# Classify each character in the input string
is_master = [c in master_char_set for c in text]
for pos in xrange(1, len(text)):
if is_master[pos] != is_master[pos - 1]:
token = text[token_start:pos]
if token != u" " or token_start == 0:
ret.append(token)
token_start = pos
final_token = text[token_start:]
ret.append(final_token)
return ret
def _join_tokens_to_string(tokens, master_char_set):
"""Join a list of string tokens into a single string."""
token_is_master = [t[0] in master_char_set for t in tokens]
ret = []
for i, token in enumerate(tokens):
if i > 0 and token_is_master[i - 1] and token_is_master[i]:
ret.append(u" ")
ret.append(token)
return "".join(ret)
def _escape_token(token, alphabet):
r"""Replace characters that aren't in the alphabet and append "_" to token.
Apply three transformations to the token:
1. Replace underline character "_" with "\u", and backslash "\" with "\\".
2. Replace characters outside of the alphabet with "\###;", where ### is the
character's Unicode code point.
3. Appends "_" to mark the end of a token.
Args:
token: unicode string to be escaped
alphabet: list of all known characters
Returns:
escaped string
"""
token = token.replace(u"\\", u"\\\\").replace(u"_", u"\\u")
ret = [c if c in alphabet and c != u"\n" else r"\%d;" % ord(c) for c in token]
return u"".join(ret) + "_"
def _unescape_token(token):
r"""Replaces escaped characters in the token with their unescaped versions.
Applies inverse transformations as _escape_token():
1. Replace "\u" with "_", and "\\" with "\".
2. Replace "\###;" with the unicode character the ### refers to.
Args:
token: escaped string
Returns:
unescaped string
"""
def match(m):
r"""Returns replacement string for matched object.
Matched objects contain one of the strings that matches the regex pattern:
r"\\u|\\\\|\\([0-9]+);"
The strings can be '\u', '\\', or '\###;' (### is any digit number).
m.group(0) refers to the entire matched string ('\u', '\\', or '\###;').
m.group(1) refers to the first parenthesized subgroup ('###').
m.group(0) exists for all match objects, while m.group(1) exists only for
the string '\###;'.
This function looks to see if m.group(1) exists. If it doesn't, then the
matched string must be '\u' or '\\' . In this case, the corresponding
replacement ('_' and '\') are returned. Note that in python, a single
backslash is written as '\\', and double backslash as '\\\\'.
If m.goup(1) exists, then use the integer in m.group(1) to return a
unicode character.
Args:
m: match object
Returns:
String to replace matched object with.
"""
# Check if the matched strings are '\u' or '\\'.
if m.group(1) is None:
return u"_" if m.group(0) == u"\\u" else u"\\"
# If m.group(1) exists, try and return unicode character.
try:
return six.unichr(int(m.group(1)))
except (ValueError, OverflowError) as _:
return _UNDEFINED_UNICODE
# Use match function to replace escaped substrings in the token.
return _UNESCAPE_REGEX.sub(match, token)
def _count_tokens(files,
file_byte_limit=1e6,
correct_strip=True,
master_char_set=None):
"""Return token counts of words in the files.
Samples file_byte_limit bytes from each file, and counts the words that appear
in the samples. The samples are semi-evenly distributed across the file.
Args:
files: List of filepaths
file_byte_limit: Max number of bytes that will be read from each file.
correct_strip: Whether to convert text to unicode before strip. This affects
vocabulary generation for PY2. Sets correct_strip to False in PY2 to
reproduce previous common public result. Sets correct_strip to True will
let PY2 and PY3 get a consistent vocabulary.
master_char_set: the char set.
Returns:
Dictionary mapping tokens to the number of times they appear in the sampled
lines from the files.
"""
if master_char_set is None:
master_char_set = _ALPHANUMERIC_CHAR_SET
token_counts = collections.defaultdict(int)
for filepath in files:
with tf.io.gfile.GFile(filepath, mode="r") as reader:
file_byte_budget = file_byte_limit
counter = 0
lines_to_skip = int(reader.size() / (file_byte_budget * 2))
for line in reader:
if counter < lines_to_skip:
counter += 1
else:
if file_byte_budget < 0:
break
if correct_strip:
line = native_to_unicode(line)
line = line.strip()
file_byte_budget -= len(line)
counter = 0
# Add words to token counts
for token in _split_string_to_tokens(
native_to_unicode(line), master_char_set):
token_counts[token] += 1
return token_counts
def _list_to_index_dict(lst):
"""Create dictionary mapping list items to their indices in the list."""
return {item: n for n, item in enumerate(lst)}
def _split_token_to_subtokens(token, subtoken_dict, max_subtoken_length):
"""Splits a token into subtokens defined in the subtoken dict."""
ret = []
start = 0
token_len = len(token)
while start < token_len:
# Find the longest subtoken, so iterate backwards.
for end in xrange(min(token_len, start + max_subtoken_length), start, -1):
subtoken = token[start:end]
if subtoken in subtoken_dict:
ret.append(subtoken)
start = end
break
else: # Did not break
# If there is no possible encoding of the escaped token then one of the
# characters in the token is not in the alphabet. This should be
# impossible and would be indicative of a bug.
raise ValueError("Was unable to split token \"%s\" into subtokens." %
token)
return ret
def _generate_subtokens_with_target_vocab_size(token_counts,
alphabet,
target_size,
threshold,
min_count=None,
reserved_tokens=None):
"""Generate subtoken vocabulary close to the target size."""
if reserved_tokens is None:
reserved_tokens = RESERVED_TOKENS
if min_count is not None:
logging.info("Using min_count=%d to generate vocab with target size %d",
min_count, target_size)
return _generate_subtokens(
token_counts, alphabet, min_count, reserved_tokens=reserved_tokens)
def bisect(min_val, max_val):
"""Recursive function to binary search for subtoken vocabulary."""
cur_count = (min_val + max_val) // 2
logging.info("Binary search: trying min_count=%d (%d %d)", cur_count,
min_val, max_val)
subtoken_list = _generate_subtokens(
token_counts, alphabet, cur_count, reserved_tokens=reserved_tokens)
val = len(subtoken_list)
logging.info("Binary search: min_count=%d resulted in %d tokens", cur_count,
val)
within_threshold = abs(val - target_size) < threshold
if within_threshold or min_val >= max_val or cur_count < 2:
return subtoken_list
if val > target_size:
other_subtoken_list = bisect(cur_count + 1, max_val)
else:
other_subtoken_list = bisect(min_val, cur_count - 1)
# Return vocabulary dictionary with the closest number of tokens.
other_val = len(other_subtoken_list)
if abs(other_val - target_size) < abs(val - target_size):
return other_subtoken_list
return subtoken_list
logging.info("Finding best min_count to get target size of %d", target_size)
return bisect(_MIN_MIN_COUNT, _MAX_MIN_COUNT)
def _generate_alphabet_dict(iterable, reserved_tokens=None):
"""Create set of characters that appear in any element in the iterable."""
if reserved_tokens is None:
reserved_tokens = RESERVED_TOKENS
alphabet = {c for token in iterable for c in token}
alphabet |= {c for token in reserved_tokens for c in token}
alphabet |= _ESCAPE_CHARS # Add escape characters to alphabet set.
return alphabet
def _count_and_gen_subtokens(token_counts, alphabet, subtoken_dict,
max_subtoken_length):
"""Count number of times subtokens appear, and generate new subtokens.
Args:
token_counts: dict mapping tokens to the number of times they appear in the
original files.
alphabet: list of allowed characters. Used to escape the tokens, which
guarantees that all tokens can be split into subtokens.
subtoken_dict: dict mapping subtokens to ids.
max_subtoken_length: maximum length of subtoken in subtoken_dict.
Returns:
A defaultdict mapping subtokens to the number of times they appear in the
tokens. The dict may contain new subtokens.
"""
subtoken_counts = collections.defaultdict(int)
for token, count in six.iteritems(token_counts):
token = _escape_token(token, alphabet)
subtokens = _split_token_to_subtokens(token, subtoken_dict,
max_subtoken_length)
# Generate new subtokens by taking substrings from token.
start = 0
for subtoken in subtokens:
for end in xrange(start + 1, len(token) + 1):
new_subtoken = token[start:end]
subtoken_counts[new_subtoken] += count
start += len(subtoken)
return subtoken_counts
def _filter_and_bucket_subtokens(subtoken_counts, min_count):
"""Return a bucketed list of subtokens that are filtered by count.
Args:
subtoken_counts: defaultdict mapping subtokens to their counts
min_count: int count used to filter subtokens
Returns:
List of subtoken sets, where subtokens in set i have the same length=i.
"""
# Create list of buckets, where subtokens in bucket i have length i.
subtoken_buckets = []
for subtoken, count in six.iteritems(subtoken_counts):
if count < min_count: # Filter out subtokens that don't appear enough
continue
while len(subtoken_buckets) <= len(subtoken):
subtoken_buckets.append(set())
subtoken_buckets[len(subtoken)].add(subtoken)
return subtoken_buckets
def _gen_new_subtoken_list(subtoken_counts,
min_count,
alphabet,
reserved_tokens=None):
"""Generate candidate subtokens ordered by count, and new max subtoken length.
Add subtokens to the candiate list in order of length (longest subtokens
first). When a subtoken is added, the counts of each of its prefixes are
decreased. Prefixes that don't appear much outside the subtoken are not added
to the candidate list.
For example:
subtoken being added to candidate list: 'translate'
subtoken_counts: {'translate':10, 't':40, 'tr':16, 'tra':12, ...}
min_count: 5
When 'translate' is added, subtoken_counts is updated to:
{'translate':0, 't':30, 'tr':6, 'tra': 2, ...}
The subtoken 'tra' will not be added to the candidate list, because it appears
twice (less than min_count) outside of 'translate'.
Args:
subtoken_counts: defaultdict mapping str subtokens to int counts
min_count: int minumum count requirement for subtokens
alphabet: set of characters. Each character is added to the subtoken list to
guarantee that all tokens can be encoded.
reserved_tokens: list of tokens that will be added to the beginning of the
returned subtoken list.
Returns:
List of candidate subtokens in decreasing count order, and maximum subtoken
length
"""
if reserved_tokens is None:
reserved_tokens = RESERVED_TOKENS
# Create a list of (count, subtoken) for each candidate subtoken.
subtoken_candidates = []
# Use bucketted list to iterate through subtokens in order of length.
# subtoken_buckets[i] = set(subtokens), where each subtoken has length i.
subtoken_buckets = _filter_and_bucket_subtokens(subtoken_counts, min_count)
max_subtoken_length = len(subtoken_buckets) - 1
# Go through the list in reverse order to consider longer subtokens first.
for subtoken_len in xrange(max_subtoken_length, 0, -1):
for subtoken in subtoken_buckets[subtoken_len]:
count = subtoken_counts[subtoken]
# Possible if this subtoken is a prefix of another token.
if count < min_count:
continue
# Ignore alphabet/reserved tokens, which will be added manually later.
if subtoken not in alphabet and subtoken not in reserved_tokens:
subtoken_candidates.append((count, subtoken))
# Decrement count of the subtoken's prefixes (if a longer subtoken is
# added, its prefixes lose priority to be added).
for end in xrange(1, subtoken_len):
subtoken_counts[subtoken[:end]] -= count
# Add alphabet subtokens (guarantees that all strings are encodable).
subtoken_candidates.extend((subtoken_counts.get(a, 0), a) for a in alphabet)
# Order subtoken candidates by decreasing count.
subtoken_list = [t for _, t in sorted(subtoken_candidates, reverse=True)]
# Add reserved tokens to beginning of the list.
subtoken_list = reserved_tokens + subtoken_list
return subtoken_list, max_subtoken_length
def _generate_subtokens(token_counts,
alphabet,
min_count,
num_iterations=4,
reserved_tokens=None):
"""Create a list of subtokens in decreasing order of frequency.
Args:
token_counts: dict mapping str tokens -> int count
alphabet: set of characters
min_count: int minimum number of times a subtoken must appear before it is
added to the vocabulary.
num_iterations: int number of iterations to generate new tokens.
reserved_tokens: list of tokens that will be added to the beginning to the
returned subtoken list.
Returns:
Sorted list of subtokens (most frequent first)
"""
if reserved_tokens is None:
reserved_tokens = RESERVED_TOKENS
# Use alphabet set to create initial list of subtokens
subtoken_list = reserved_tokens + list(alphabet)
max_subtoken_length = 1
# On each iteration, segment all words using the subtokens defined in
# subtoken_dict, count how often the resulting subtokens appear, and update
# the dictionary with subtokens w/ high enough counts.
for i in xrange(num_iterations):
logging.info("\tGenerating subtokens: iteration %d", i)
# Generate new subtoken->id dictionary using the new subtoken list.
subtoken_dict = _list_to_index_dict(subtoken_list)
# Create dict mapping subtoken->count, with additional subtokens created
# from substrings taken from the tokens.
subtoken_counts = _count_and_gen_subtokens(token_counts, alphabet,
subtoken_dict,
max_subtoken_length)
# Generate new list of subtokens sorted by subtoken count.
subtoken_list, max_subtoken_length = _gen_new_subtoken_list(
subtoken_counts, min_count, alphabet, reserved_tokens)
logging.info("\tVocab size: %d", len(subtoken_list))
return subtoken_list
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test Subtokenizer and string helper methods."""
import collections
import tempfile
import tensorflow as tf
from official.legacy.transformer.utils import tokenizer
class SubtokenizerTest(tf.test.TestCase):
def _init_subtokenizer(self, vocab_list):
temp_file = tempfile.NamedTemporaryFile(delete=False)
with tf.io.gfile.GFile(temp_file.name, "w") as w:
for subtoken in vocab_list:
w.write("'%s'" % subtoken)
w.write("\n")
return tokenizer.Subtokenizer(temp_file.name, reserved_tokens=[])
def test_encode(self):
vocab_list = ["123_", "test", "ing_"]
subtokenizer = self._init_subtokenizer(vocab_list)
s = "testing 123"
encoded_list = subtokenizer.encode(s)
self.assertEqual([1, 2, 0], encoded_list)
def test_decode(self):
vocab_list = ["123_", "test", "ing_"]
subtokenizer = self._init_subtokenizer(vocab_list)
encoded_list = [1, 2, 0] # testing 123
decoded_str = subtokenizer.decode(encoded_list)
self.assertEqual("testing 123", decoded_str)
def test_subtoken_ids_to_tokens(self):
vocab_list = ["123_", "test", "ing_"]
subtokenizer = self._init_subtokenizer(vocab_list)
encoded_list = [1, 2, 0] # testing 123
token_list = subtokenizer._subtoken_ids_to_tokens(encoded_list)
self.assertEqual([u"testing", u"123"], token_list)
class StringHelperTest(tf.test.TestCase):
def test_split_string_to_tokens(self):
text = "test? testing 123."
tokens = tokenizer._split_string_to_tokens(text,
tokenizer._ALPHANUMERIC_CHAR_SET)
self.assertEqual(["test", "? ", "testing", "123", "."], tokens)
def test_join_tokens_to_string(self):
tokens = ["test", "? ", "testing", "123", "."]
s = tokenizer._join_tokens_to_string(tokens,
tokenizer._ALPHANUMERIC_CHAR_SET)
self.assertEqual("test? testing 123.", s)
def test_escape_token(self):
token = u"abc_\\4"
alphabet = set("abc_\\u;")
escaped_token = tokenizer._escape_token(token, alphabet)
self.assertEqual("abc\\u\\\\\\52;_", escaped_token)
def test_unescape_token(self):
escaped_token = u"Underline: \\u, Backslash: \\\\, Unicode: \\52;"
unescaped_token = tokenizer._unescape_token(escaped_token)
self.assertEqual("Underline: _, Backslash: \\, Unicode: 4", unescaped_token)
def test_list_to_index_dict(self):
lst = ["test", "strings"]
d = tokenizer._list_to_index_dict(lst)
self.assertDictEqual({"test": 0, "strings": 1}, d)
def test_split_token_to_subtokens(self):
token = "abc"
subtoken_dict = {"a": 0, "b": 1, "c": 2, "ab": 3}
max_subtoken_length = 2
subtokens = tokenizer._split_token_to_subtokens(token, subtoken_dict,
max_subtoken_length)
self.assertEqual(["ab", "c"], subtokens)
def test_generate_alphabet_dict(self):
s = ["testing", "123"]
reserved_tokens = ["???"]
alphabet = tokenizer._generate_alphabet_dict(s, reserved_tokens)
self.assertIn("?", alphabet)
self.assertIn("t", alphabet)
self.assertIn("e", alphabet)
self.assertIn("s", alphabet)
self.assertIn("i", alphabet)
self.assertIn("n", alphabet)
self.assertIn("g", alphabet)
self.assertIn("1", alphabet)
self.assertIn("2", alphabet)
self.assertIn("3", alphabet)
def test_count_and_gen_subtokens(self):
token_counts = {"abc": 5}
alphabet = set("abc_")
subtoken_dict = {"a": 0, "b": 1, "c": 2, "_": 3}
max_subtoken_length = 2
subtoken_counts = tokenizer._count_and_gen_subtokens(
token_counts, alphabet, subtoken_dict, max_subtoken_length)
self.assertIsInstance(subtoken_counts, collections.defaultdict)
self.assertDictEqual(
{
"a": 5,
"b": 5,
"c": 5,
"_": 5,
"ab": 5,
"bc": 5,
"c_": 5,
"abc": 5,
"bc_": 5,
"abc_": 5
}, subtoken_counts)
def test_filter_and_bucket_subtokens(self):
subtoken_counts = collections.defaultdict(int, {
"a": 2,
"b": 4,
"c": 1,
"ab": 6,
"ac": 3,
"abbc": 5
})
min_count = 3
subtoken_buckets = tokenizer._filter_and_bucket_subtokens(
subtoken_counts, min_count)
self.assertEqual(len(subtoken_buckets[0]), 0)
self.assertEqual(set("b"), subtoken_buckets[1])
self.assertEqual(set(["ab", "ac"]), subtoken_buckets[2])
self.assertEqual(len(subtoken_buckets[3]), 0)
self.assertEqual(set(["abbc"]), subtoken_buckets[4])
def test_gen_new_subtoken_list(self):
subtoken_counts = collections.defaultdict(int, {
"translate": 10,
"t": 40,
"tr": 16,
"tra": 12
})
min_count = 5
alphabet = set("translate")
reserved_tokens = ["reserved", "tokens"]
subtoken_list, max_token_length = tokenizer._gen_new_subtoken_list(
subtoken_counts, min_count, alphabet, reserved_tokens)
# Check that "tra" isn"t in the list (its count should be decremented to 2,
# so it should not be added to the canddiate list).
self.assertNotIn("tra", subtoken_list)
self.assertIn("tr", subtoken_list)
self.assertIn("t", subtoken_list)
self.assertEqual(len("translate"), max_token_length)
def test_generate_subtokens(self):
token_counts = {"ab": 1, "bc": 3, "abc": 5}
alphabet = set("abc_")
min_count = 100
num_iterations = 1
reserved_tokens = ["reserved", "tokens"]
vocab_list = tokenizer._generate_subtokens(token_counts, alphabet,
min_count, num_iterations,
reserved_tokens)
# Check that reserved tokens are at the front of the list
self.assertEqual(vocab_list[:2], reserved_tokens)
# Check that each character in alphabet is in the vocab list
for c in alphabet:
self.assertIn(c, vocab_list)
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Stacking model horizontally."""
from absl import logging
import numpy as np
import tensorflow as tf
def expand_vector(v: np.ndarray) -> np.ndarray:
"""Expands a vector with batch dimensions.
Equivalent to expand_1_axis(v, epsilon=0.0, axis=-1)
Args:
v: A vector with shape [..., a].
Returns:
A vector with shape [..., 2 * a].
"""
return np.repeat(v, 2, axis=-1)
def expand_1_axis(w: np.ndarray,
epsilon: float,
axis: int) -> np.ndarray:
"""Expands either the first dimension or the last dimension of w.
If `axis = 0`, the following constraint will be satisfied:
matmul(x, w) ==
matmul(expand_vector(x), expand_1_axis(w, epsilon=0.1, axis=0))
If `axis = -1`, the following constraint will be satisfied if `epsilon = 0.0`:
expand_vector(matmul(x, w)) ==
2 * matmul(x, expand_1_axis(w, epsilon=0.0, axis=-1))
Args:
w: Numpy array of shape [a_0, a_1, ..., a_i-1, a_i].
epsilon: Symmetric Noise added to expanded tensor.
axis: Must be either 0 or -1.
Returns:
Expanded numpy array.
"""
assert axis in (0, -1), (
"Only support expanding the first or the last dimension. "
"Got: {}".format(axis))
rank = len(w.shape)
d_w = np.random.normal(np.zeros_like(w), np.fabs(w) * epsilon, w.shape)
d_w = np.repeat(d_w, 2, axis=axis)
sign_flip = np.array([1, -1])
for _ in range(rank - 1):
sign_flip = np.expand_dims(sign_flip, axis=-1 if axis == 0 else 0)
sign_flip = np.tile(sign_flip,
[w.shape[0]] + [1] * (rank - 2) + [w.shape[-1]])
d_w *= sign_flip
w_expand = (np.repeat(w, 2, axis=axis) + d_w) / 2
return w_expand
def expand_2_axes(w: np.ndarray,
epsilon: float) -> np.ndarray:
"""Expands the first dimension and the last dimension of w.
The following constraint will be satisfied:
expand_vector(matmul(x, w)) == matmul(expand_vector(x), expand_2_axes(w))
Args:
w: Numpy array of shape [a_0, a_1, ..., a_i-1, a_i].
epsilon: Symmetric Noise added to expanded tensor.
Returns:
Expanded numpy array.
"""
rank = len(w.shape)
d_w = np.random.normal(np.zeros_like(w), np.fabs(w) * epsilon, w.shape)
d_w = np.repeat(np.repeat(d_w, 2, axis=0), 2, axis=-1)
sign_flip = np.array([1, -1])
for _ in range(rank - 1):
sign_flip = np.expand_dims(sign_flip, axis=-1)
sign_flip = np.tile(sign_flip,
[w.shape[0]] + [1] * (rank - 2) + [w.shape[-1] * 2])
d_w *= sign_flip
w_expand = (np.repeat(np.repeat(w, 2, axis=0), 2, axis=-1) + d_w) / 2
return w_expand
def var_to_var(var_from: tf.Variable,
var_to: tf.Variable,
epsilon: float):
"""Expands a variable to another variable.
Assume the shape of `var_from` is (a, b, ..., y, z), the shape of `var_to`
can be (a, ..., z * 2), (a * 2, ..., z * 2), (a * 2, ..., z)
If the shape of `var_to` is (a, ..., 2 * z):
For any x, tf.matmul(x, var_to) ~= expand_vector(tf.matmul(x, var_from)) / 2
Not that there will be noise added to the left hand side, if epsilon != 0.
If the shape of `var_to` is (2 * a, ..., z):
For any x, tf.matmul(expand_vector(x), var_to) == tf.matmul(x, var_from)
If the shape of `var_to` is (2 * a, ..., 2 * z):
For any x, tf.matmul(expand_vector(x), var_to) ==
expand_vector(tf.matmul(expand_vector(x), var_from))
Args:
var_from: input variable to expand.
var_to: output variable.
epsilon: the noise ratio that will be added, when splitting `var_from`.
"""
shape_from = var_from.shape
shape_to = var_to.shape
if shape_from == shape_to:
var_to.assign(var_from)
elif len(shape_from) == 1 and len(shape_to) == 1:
var_to.assign(expand_vector(var_from.numpy()))
elif shape_from[0] * 2 == shape_to[0] and shape_from[-1] == shape_to[-1]:
var_to.assign(expand_1_axis(var_from.numpy(), epsilon=epsilon, axis=0))
elif shape_from[0] == shape_to[0] and shape_from[-1] * 2 == shape_to[-1]:
var_to.assign(expand_1_axis(var_from.numpy(), epsilon=epsilon, axis=-1))
elif shape_from[0] * 2 == shape_to[0] and shape_from[-1] * 2 == shape_to[-1]:
var_to.assign(expand_2_axes(var_from.numpy(), epsilon=epsilon))
else:
raise ValueError("Shape not supported, {}, {}".format(shape_from, shape_to))
def model_to_model_2x_wide(model_from: tf.Module,
model_to: tf.Module,
epsilon: float = 0.1):
"""Expands a model to a wider version.
Also makes sure that the output of the model is not changed after expanding.
For example:
```
model_narrow = tf.keras.Sequential()
model_narrow.add(tf.keras.Input(shape=(3,)))
model_narrow.add(tf.keras.layers.Dense(4))
model_narrow.add(tf.keras.layers.Dense(1))
model_wide = tf.keras.Sequential()
model_wide.add(tf.keras.Input(shape=(6,)))
model_wide.add(tf.keras.layers.Dense(8))
model_wide.add(tf.keras.layers.Dense(1))
model_to_model_2x_wide(model_narrow, model_wide)
assert model_narrow([[1, 2, 3]]) == model_wide([[1, 1, 2, 2, 3, 3]])
```
We assume that `model_from` and `model_to` has the same architecture and only
widths of them differ.
Args:
model_from: input model to expand.
model_to: output model whose variables will be assigned expanded values
according to `model_from`.
epsilon: the noise ratio that will be added, when splitting `var_from`.
"""
for w_from, w_to in zip(model_from.trainable_variables,
model_to.trainable_variables):
logging.info("expanding %s %s to %s %s",
w_from.name, w_from.shape, w_to.name, w_to.shape)
var_to_var(w_from, w_to, epsilon=epsilon)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tf2_utils_2x_wide."""
import numpy as np
import tensorflow as tf
from official.modeling.fast_training.experimental import tf2_utils_2x_wide
class Tf2Utils2XWideTest(tf.test.TestCase):
def test_expand_vector(self):
x = np.array([1, 2])
self.assertAllClose(tf2_utils_2x_wide.expand_vector(x),
np.array([1, 1, 2, 2]))
def test_expand_matrix(self):
x = np.array([[1, 2], [3, 4]])
x = tf2_utils_2x_wide.expand_2_axes(x, epsilon=0.1)
self.assertAllClose(x[0, :] + x[1, :], np.array([1, 1, 2, 2]))
self.assertAllClose(x[2, :] + x[3, :], np.array([3, 3, 4, 4]))
def test_expand_matrix_axis_0(self):
x = np.array([[1, 2], [3, 4]])
x = tf2_utils_2x_wide.expand_1_axis(x, axis=0, epsilon=0.1)
self.assertAllClose(x[0, :] + x[1, :], np.array([1, 2]))
self.assertAllClose(x[2, :] + x[3, :], np.array([3, 4]))
def test_expand_matrix_axis_1(self):
x = np.array([[1, 2], [3, 4]])
x = tf2_utils_2x_wide.expand_1_axis(x, axis=-1, epsilon=0.1)
self.assertAllClose(x[:, 0] + x[:, 1], np.array([1, 3]))
self.assertAllClose(x[:, 2] + x[:, 3], np.array([2, 4]))
def test_expand_3d_tensor(self):
x0 = np.array([10, 11])
x1 = np.array([10, 10, 11, 11])
w0 = np.random.rand(2, 2)
w1 = tf2_utils_2x_wide.expand_2_axes(w0, epsilon=0.1)
o0 = np.matmul(x0, w0)
o1 = np.matmul(x1, w1)
self.assertAllClose(np.repeat(o0, 2, axis=-1), o1)
def test_expand_3d_tensor_axis_0(self):
x0 = np.array([10, 11])
x1 = np.array([10, 10, 11, 11])
w0 = np.random.rand(2, 2)
w1 = tf2_utils_2x_wide.expand_1_axis(w0, axis=0, epsilon=0.1)
o0 = np.matmul(x0, w0)
o1 = np.matmul(x1, w1)
self.assertAllClose(o0, o1)
def test_expand_3d_tensor_axis_2(self):
x = np.array([10, 11])
w0 = np.random.rand(2, 2)
w1 = tf2_utils_2x_wide.expand_1_axis(w0, axis=-1, epsilon=0.1)
o0 = np.matmul(x, w0)
o1 = np.matmul(x, w1)
self.assertAllClose(o0, np.sum(o1.reshape(2, 2), axis=-1))
def test_end_to_end(self):
"""Covers expand_vector, expand_2_axes, and expand_1_axis."""
model_narrow = tf.keras.Sequential()
model_narrow.add(tf.keras.Input(shape=(3,)))
model_narrow.add(tf.keras.layers.Dense(4))
model_narrow.add(tf.keras.layers.Dense(4))
model_narrow.add(tf.keras.layers.Dense(1))
model_wide = tf.keras.Sequential()
model_wide.add(tf.keras.Input(shape=(6,)))
model_wide.add(tf.keras.layers.Dense(8))
model_wide.add(tf.keras.layers.Dense(8))
model_wide.add(tf.keras.layers.Dense(1))
x0 = np.array([[1, 2, 3]])
x1 = np.array([[1, 1, 2, 2, 3, 3]])
# Call model once to build variables first.
_, _ = model_narrow(x0), model_wide(x1)
tf2_utils_2x_wide.model_to_model_2x_wide(
model_narrow, model_wide, epsilon=0.2)
self.assertAllClose(model_narrow(x0), model_wide(x1),
rtol=1e-05, atol=1e-05)
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base ProgressivePolicy definition for progressive training.
To write a progressive model, subclass ProgressivePolicy and implement its
abstract methods to handle each training stage.
"""
import abc
import dataclasses
from typing import Any, Mapping
from absl import logging
import six
import tensorflow as tf
from official.common import streamz_counters
from official.modeling.fast_training.progressive import utils
from official.modeling.hyperparams import base_config
@dataclasses.dataclass
class ProgressiveConfig(base_config.Config):
pass
@six.add_metaclass(abc.ABCMeta)
class ProgressivePolicy:
"""The APIs for handling progressive training stages.
Attributes:
cur_model: The model for the current progressive training stage.
cur_train_dataset: The train dataset function for the current stage.
cur_eval_dataset: The eval dataset function for the current stage.
cur_optimizer: The optimizer for the current stage.
cur_checkpoint_items: Items to be saved in and restored from checkpoints,
for the progressive trainer.
is_last_stage: Whether it is currently in the last stage.
Interfaces:
is_stage_advancing: Returns if progressive training is advancing to the
next stage.
update_pt_stage: Update progressive training stage.
"""
def __init__(self):
"""Initialize stage policy."""
self._cur_train_dataset = None
self._cur_eval_dataset = None
self._volatiles = utils.VolatileTrackable(optimizer=None, model=None)
stage_id = 0
self._stage_id = tf.Variable(
stage_id,
trainable=False,
dtype=tf.int64,
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
shape=[])
self._volatiles.reassign_trackable(
optimizer=self.get_optimizer(stage_id),
model=self.get_model(stage_id, old_model=None)) # pytype: disable=wrong-arg-types # typed-keras
streamz_counters.progressive_policy_creation_counter.get_cell(
).increase_by(1)
def compute_stage_id(self, global_step: int) -> int:
for stage_id in range(self.num_stages()):
global_step -= self.num_steps(stage_id)
if global_step < 0:
return stage_id
logging.error('Global step %d found no matching progressive stages. '
'Default to the last stage.', global_step)
return self.num_stages() - 1
@abc.abstractmethod
def num_stages(self) -> int:
"""Return the total number of progressive stages."""
pass
@abc.abstractmethod
def num_steps(self, stage_id: int) -> int:
"""Return the total number of steps in this stage."""
pass
@abc.abstractmethod
def get_model(self,
stage_id: int,
old_model: tf.keras.Model = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Return model for this stage. For initialization, `old_model` = None."""
pass
@abc.abstractmethod
def get_optimizer(self, stage_id: int) -> tf.keras.optimizers.Optimizer:
"""Return optimizer for this stage."""
pass
@abc.abstractmethod
def get_train_dataset(self, stage_id: int) -> tf.data.Dataset:
"""Return training Dataset for this stage."""
pass
@abc.abstractmethod
def get_eval_dataset(self, stage_id: int) -> tf.data.Dataset:
"""Return evaluation Dataset for this stage."""
pass
@property
def cur_model(self) -> tf.keras.Model:
return self._volatiles.model
@property
def cur_train_dataset(self) -> tf.data.Dataset:
if self._cur_train_dataset is None:
self._cur_train_dataset = self.get_train_dataset(self._stage_id.numpy())
return self._cur_train_dataset
@property
def cur_eval_dataset(self) -> tf.data.Dataset:
if self._cur_eval_dataset is None:
self._cur_eval_dataset = self.get_eval_dataset(self._stage_id.numpy())
return self._cur_eval_dataset
@property
def cur_optimizer(self) -> tf.keras.optimizers.Optimizer:
return self._volatiles.optimizer
@property
def is_last_stage(self) -> bool:
stage_id = self._stage_id.numpy()
return stage_id >= self.num_stages() - 1
@property
def cur_checkpoint_items(self) -> Mapping[str, Any]:
return dict(stage_id=self._stage_id, volatiles=self._volatiles)
def is_stage_advancing(self, global_step: int) -> bool:
old_stage_id = self._stage_id.numpy()
new_stage_id = self.compute_stage_id(global_step)
return old_stage_id != new_stage_id
def update_pt_stage(self, global_step: int, pass_old_model=True) -> None:
"""Update progressive training internal status.
Call this after a training loop ends.
Args:
global_step: an integer scalar of the current global step.
pass_old_model: whether to pass the old_model to get_model() function.
This is set to False if the old_model is irrelevant (e.g, just a default
model from stage 0).
"""
old_stage_id = self._stage_id.numpy()
new_stage_id = self.compute_stage_id(global_step)
logging.info('Switching stage from %d to %d', old_stage_id, new_stage_id)
# Update stage id.
self._stage_id.assign(new_stage_id)
# Update dataset function.
self._cur_train_dataset = None
self._cur_eval_dataset = None
# Update optimizer and model.
new_optimizer = self.get_optimizer(new_stage_id)
self._volatiles.reassign_trackable(optimizer=new_optimizer)
new_model = self.get_model(
new_stage_id, old_model=self.cur_model if pass_old_model else None)
self._volatiles.reassign_trackable(model=new_model)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TFM binary for the progressive trainer."""
from absl import app
from absl import flags
import gin
from official.common import distribute_utils
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_utils
from official.modeling import performance
from official.modeling.fast_training.progressive import train_lib
FLAGS = flags.FLAGS
def main(_):
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
params = train_utils.parse_configuration(FLAGS)
model_dir = FLAGS.model_dir
if 'train' in FLAGS.mode:
# Pure eval modes do not output yaml files. Otherwise continuous eval job
# may race against the train job for writing the same file.
train_utils.serialize_config(params, model_dir)
# Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
# can have significant impact on model speeds by utilizing float16 in case of
# GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
# dtype is float16
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu,
**params.runtime.model_parallelism())
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=FLAGS.mode,
params=params,
model_dir=model_dir)
train_utils.save_gin_config(FLAGS.mode, model_dir)
if __name__ == '__main__':
tfm_flags.define_flags()
app.run(main)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TFM progressive training driver library.
Compared to the common training driver, the only difference is that we use
prog_trainer_lib.ProgressiveTrainer instead of the base trainer.
"""
# pytype: disable=attribute-error
import os
from typing import Any, Mapping, Tuple
# Import libraries
from absl import logging
import orbit
import tensorflow as tf
from official.core import base_task
from official.core import config_definitions
from official.core import train_lib as base_train_lib
from official.modeling.fast_training.progressive import trainer as prog_trainer_lib
def run_experiment(distribution_strategy: tf.distribute.Strategy,
task: base_task.Task,
mode: str,
params: config_definitions.ExperimentConfig,
model_dir: str,
run_post_eval: bool = False,
save_summary: bool = True) \
-> Tuple[tf.keras.Model, Mapping[str, Any]]:
"""Runs train/eval configured by the experiment params.
Args:
distribution_strategy: A distribution distribution_strategy.
task: A Task instance.
mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
or 'continuous_eval'.
params: ExperimentConfig instance.
model_dir: A 'str', a path to store model checkpoints and summaries.
run_post_eval: Whether to run post eval once after training, metrics logs
are returned.
save_summary: Whether to save train and validation summary.
Returns:
A 2-tuple of (model, eval_logs).
model: `tf.keras.Model` instance.
eval_logs: returns eval metrics logs when run_post_eval is set to True,
otherwise, returns {}.
"""
with distribution_strategy.scope():
logging.info('Running progressive trainer.')
trainer = prog_trainer_lib.ProgressiveTrainer(
params, task, ckpt_dir=model_dir,
train='train' in mode,
evaluate=('eval' in mode) or run_post_eval,
checkpoint_exporter=base_train_lib.maybe_create_best_ckpt_exporter(
params, model_dir))
if trainer.checkpoint:
checkpoint_manager = tf.train.CheckpointManager(
trainer.checkpoint,
directory=model_dir,
max_to_keep=params.trainer.max_to_keep,
step_counter=trainer.global_step,
checkpoint_interval=params.trainer.checkpoint_interval,
init_fn=trainer.initialize)
else:
checkpoint_manager = None
controller = orbit.Controller(
strategy=distribution_strategy,
trainer=trainer if 'train' in mode else None,
evaluator=trainer,
global_step=trainer.global_step,
steps_per_loop=params.trainer.steps_per_loop,
checkpoint_manager=checkpoint_manager,
summary_dir=os.path.join(model_dir, 'train') if (save_summary) else None,
eval_summary_dir=os.path.join(model_dir, 'validation') if
(save_summary) else None,
summary_interval=params.trainer.summary_interval if
(save_summary) else None)
logging.info('Starts to execute mode: %s', mode)
with distribution_strategy.scope():
if mode == 'train':
controller.train(steps=params.trainer.train_steps)
elif mode == 'train_and_eval':
controller.train_and_evaluate(
train_steps=params.trainer.train_steps,
eval_steps=params.trainer.validation_steps,
eval_interval=params.trainer.validation_interval)
elif mode == 'eval':
controller.evaluate(steps=params.trainer.validation_steps)
elif mode == 'continuous_eval':
def timeout_fn():
if trainer.global_step.numpy() >= params.trainer.train_steps:
return True
return False
controller.evaluate_continuously(
steps=params.trainer.validation_steps,
timeout=params.trainer.continuous_eval_timeout,
timeout_fn=timeout_fn)
else:
raise NotImplementedError('The mode is not implemented: %s' % mode)
if run_post_eval:
with distribution_strategy.scope():
return trainer.model, trainer.evaluate(
tf.convert_to_tensor(params.trainer.validation_steps))
else:
return trainer.model, {}
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the progressive train_lib."""
import os
from absl import flags
from absl.testing import parameterized
import dataclasses
import orbit
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.common import flags as tfm_flags
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling import optimization
from official.modeling.hyperparams import params_dict
from official.modeling.fast_training.progressive import policies
from official.modeling.fast_training.progressive import train_lib
from official.modeling.fast_training.progressive import trainer as prog_trainer_lib
from official.utils.testing import mock_task
FLAGS = flags.FLAGS
tfm_flags.define_flags()
@dataclasses.dataclass
class ProgTaskConfig(cfg.TaskConfig):
pass
@task_factory.register_task_cls(ProgTaskConfig)
class ProgMockTask(policies.ProgressivePolicy, mock_task.MockTask):
"""Progressive task for testing."""
def __init__(self, params: cfg.TaskConfig, logging_dir: str = None):
mock_task.MockTask.__init__(
self, params=params, logging_dir=logging_dir)
policies.ProgressivePolicy.__init__(self)
def num_stages(self):
return 2
def num_steps(self, stage_id):
return 2 if stage_id == 0 else 4
def get_model(self, stage_id, old_model=None):
del stage_id, old_model
return self.build_model()
def get_optimizer(self, stage_id):
"""Build optimizer for each stage."""
params = optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
},
'learning_rate': {
'type': 'polynomial',
'polynomial': {
'initial_learning_rate': 0.01,
'end_learning_rate': 0.0,
'power': 1.0,
'decay_steps': 10,
},
},
'warmup': {
'polynomial': {
'power': 1,
'warmup_steps': 2,
},
'type': 'polynomial',
}
})
opt_factory = optimization.OptimizerFactory(params)
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
return optimizer
def get_train_dataset(self, stage_id):
del stage_id
strategy = tf.distribute.get_strategy()
return orbit.utils.make_distributed_dataset(
strategy, self.build_inputs, None)
def get_eval_dataset(self, stage_id):
del stage_id
strategy = tf.distribute.get_strategy()
return orbit.utils.make_distributed_dataset(
strategy, self.build_inputs, None)
class TrainTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(TrainTest, self).setUp()
self._test_config = {
'trainer': {
'checkpoint_interval': 10,
'steps_per_loop': 10,
'summary_interval': 10,
'train_steps': 10,
'validation_steps': 5,
'validation_interval': 10,
'continuous_eval_timeout': 1,
'optimizer_config': {
'optimizer': {
'type': 'sgd',
},
'learning_rate': {
'type': 'constant'
}
}
},
}
@combinations.generate(
combinations.combine(
distribution_strategy=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
flag_mode=['train', 'eval', 'train_and_eval'],
run_post_eval=[True, False]))
def test_end_to_end(self, distribution_strategy, flag_mode, run_post_eval):
model_dir = self.get_temp_dir()
experiment_config = cfg.ExperimentConfig(
trainer=prog_trainer_lib.ProgressiveTrainerConfig(),
task=ProgTaskConfig())
experiment_config = params_dict.override_params_dict(
experiment_config, self._test_config, is_strict=False)
with distribution_strategy.scope():
task = task_factory.get_task(experiment_config.task,
logging_dir=model_dir)
_, logs = train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=flag_mode,
params=experiment_config,
model_dir=model_dir,
run_post_eval=run_post_eval)
if run_post_eval:
self.assertNotEmpty(logs)
else:
self.assertEmpty(logs)
if flag_mode == 'eval':
return
self.assertNotEmpty(
tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
# Tests continuous evaluation.
_, logs = train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode='continuous_eval',
params=experiment_config,
model_dir=model_dir,
run_post_eval=run_post_eval)
print(logs)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Progressive Trainer implementation.
The trainer implements the Orbit `StandardTrainable` and
`StandardEvaluable` interfaces. Trainers inside this project should be
interchangable and independent on model architectures and tasks.
"""
import dataclasses
import os
from typing import Any, Optional
# Import libraries
from absl import logging
import gin
import orbit
import tensorflow as tf
from official.core import base_task
from official.core import base_trainer as trainer_lib
from official.core import config_definitions
from official.modeling.fast_training.progressive import policies
from official.modeling.fast_training.progressive import utils
ExperimentConfig = config_definitions.ExperimentConfig
@dataclasses.dataclass
class ProgressiveTrainerConfig(config_definitions.TrainerConfig):
"""Configuration for progressive trainer.
Attributes:
progressive: A task-specific config. Users can subclass ProgressiveConfig
and define any task-specific settings in their subclass.
export_checkpoint: A bool. Whether to export checkpoints in non-progressive
manner (without the volatiles wrapper) such that your down-stream tasks
can load checkpoints from a progressive trainer as if it is a regular
checkpoint.
export_checkpoint_interval: A bool. The number of steps between exporting
checkpoints. If None (by default), will use the same value as
TrainerConfig.checkpoint_interval.
export_max_to_keep: The maximum number of exported checkpoints to keep.
If None (by default), will use the same value as
TrainerConfig.max_to_keep.
export_only_final_stage_ckpt: A bool. Whether to just export checkpoints
during the final progressive training stage. In other words, whether to
not export small, partial models. In many cases, it is not meaningful to
finetune a small, partial model in down-stream tasks.
"""
progressive: Optional[policies.ProgressiveConfig] = None
export_checkpoint: bool = True
export_checkpoint_interval: Optional[int] = None
export_max_to_keep: Optional[int] = None
export_only_final_stage_ckpt: bool = True
@gin.configurable
class ProgressiveTrainer(trainer_lib.Trainer):
"""Implements the progressive trainer shared for TensorFlow models."""
def __init__(
self,
config: ExperimentConfig,
prog_task: base_task.Task, # also implemented ProgressivePolicy.
ckpt_dir: str = '',
train: bool = True,
evaluate: bool = True,
checkpoint_exporter: Any = None):
"""Initialize common trainer for TensorFlow models.
Args:
config: An `ExperimentConfig` instance specifying experiment config.
prog_task: An instance both implemented policies.ProgressivePolicy and
base_task.Task.
ckpt_dir: Checkpoint directory.
train: bool, whether or not this trainer will be used for training.
default to True.
evaluate: bool, whether or not this trainer will be used for evaluation.
default to True.
checkpoint_exporter: an object that has the `maybe_export_checkpoint`
interface.
"""
# Gets the current distribution strategy. If not inside any strategy scope,
# it gets a single-replica no-op strategy.
self._strategy = tf.distribute.get_strategy()
self._config = config
self._runtime_options = trainer_lib.get_runtime_options(config)
self._task = prog_task
# Directory for non-progressive checkpoint
self._export_ckpt_dir = os.path.join(ckpt_dir, 'exported_ckpts')
tf.io.gfile.makedirs(self._export_ckpt_dir)
self._export_ckpt_manager = None
# Receive other checkpoint export, e.g, best checkpoint exporter.
# TODO(lehou): unify the checkpoint exporting logic, although the default
# setting does not use checkpoint_exporter.
self._checkpoint_exporter = checkpoint_exporter
self._global_step = orbit.utils.create_global_step()
self._checkpoint = utils.CheckpointWithHooks(
before_load_hook=self._update_pt_stage_from_ckpt,
global_step=self.global_step,
**self._task.cur_checkpoint_items)
self._train_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
self._validation_loss = tf.keras.metrics.Mean(
'validation_loss', dtype=tf.float32)
self._train_metrics = self.task.build_metrics(
training=True) + self.model.metrics
self._validation_metrics = self.task.build_metrics(
training=False) + self.model.metrics
if train:
orbit.StandardTrainer.__init__(
self,
None, # Manage train_dataset by ourselves, not by StandardTrainer.
options=orbit.StandardTrainerOptions(
use_tf_while_loop=config.trainer.train_tf_while_loop,
use_tf_function=config.trainer.train_tf_function))
if evaluate:
orbit.StandardEvaluator.__init__(
self,
None, # Manage train_dataset by ourselves, not by StandardEvaluator.
options=orbit.StandardEvaluatorOptions(
use_tf_function=config.trainer.eval_tf_function))
@property
def model(self):
return self._task.cur_model
@property
def optimizer(self):
return self._task.cur_optimizer
# override
@property
def train_dataset(self):
"""Overriding StandardTrainer.train_dataset."""
return self._task.cur_train_dataset
# override
@train_dataset.setter
def train_dataset(self, _):
raise SyntaxError('Please do not set train_dataset. Progressive training '
'relies on progressive policy to manager train dataset.')
# override
@property
def eval_dataset(self):
"""Overriding StandardEvaluator.eval_dataset."""
return self._task.cur_eval_dataset
# override
@eval_dataset.setter
def eval_dataset(self, _):
raise SyntaxError('Please do not set eval_dataset. Progressive training '
'relies on progressive policy to manager eval dataset.')
def train_loop_end(self):
"""See base class."""
logs = {}
for metric in self.train_metrics + [self.train_loss]:
logs[metric.name] = metric.result()
metric.reset_states()
if callable(self.optimizer.learning_rate):
logs['learning_rate'] = self.optimizer.learning_rate(
self.optimizer.iterations)
else:
logs['learning_rate'] = self.optimizer.learning_rate
self._maybe_export_non_progressive_checkpoint(self._export_ckpt_dir)
if self._task.is_stage_advancing(self.global_step.numpy()):
old_train_dataset = self.train_dataset
# Update progressive properties
self._task.update_pt_stage(self.global_step.numpy())
# Setting `self._train_loop_fn` and `self._eval_loop_fn` to None will
# rebuild the train and eval functions with the updated model.
self._train_loop_fn = None
self._eval_loop_fn = None
if self.train_dataset != old_train_dataset:
# Setting `self._train_iter` to None will rebuild the dataset iterator.
self._train_iter = None
# Setting `self._export_ckpt_manager` to None will rebuild the checkpoint
# for exporting.
self._export_ckpt_manager = None
return logs
def _update_pt_stage_from_ckpt(self, ckpt_file):
"""Update stage properties based on the global_step variable in a ckpt file.
Before loading variables from a checkpoint file, we need to go to the
correct stage and build corresponding model and optimizer, to make sure that
we retore variables of the right model and optimizer.
Args:
ckpt_file: Checkpoint file that will be restored/read from.
"""
if not ckpt_file:
return
ckpt = tf.train.Checkpoint(global_step=self.global_step)
ckpt.read(ckpt_file).expect_partial().assert_existing_objects_matched()
if self._task.is_stage_advancing(self.global_step.numpy()):
old_train_dataset = self.train_dataset
# Update progressive properties
self._task.update_pt_stage(self.global_step.numpy(), pass_old_model=False)
# Setting `self._train_loop_fn` and `self._eval_loop_fn` to None will
# rebuild the train and eval functions with the updated model.
self._train_loop_fn = None
self._eval_loop_fn = None
if self.train_dataset != old_train_dataset:
# Setting `self._train_iter` to None will rebuild the dataset iterator.
self._train_iter = None
# Setting `self._export_ckpt_manager` to None will rebuild the checkpoint
# for exporting.
self._export_ckpt_manager = None
def _maybe_export_non_progressive_checkpoint(self, export_ckpt_dir):
"""Export checkpoints in non-progressive format.
This basically removes the wrapping of self._task.cur_checkpoint_items
-- just save the model, optimizer, etc., directly.
The purpose is to let your down-stream tasks to use these checkpoints.
Args:
export_ckpt_dir: A str. folder of exported checkpoints.
"""
if not self.config.trainer.export_checkpoint:
logging.info('Not exporting checkpoints.')
return
if not self._task.is_last_stage and (
self.config.trainer.export_only_final_stage_ckpt):
logging.info('Not exporting checkpoints until the last stage.')
return
if self._export_ckpt_manager is None:
# Create a checkpoint object just now, to make sure we use
# progressive_policy.cur_model and progressive_policy.cur_optimizer of the
# current stage.
if hasattr(self.model, 'checkpoint_items'):
checkpoint_items = self.model.checkpoint_items
else:
checkpoint_items = {}
checkpoint = tf.train.Checkpoint(
global_step=self.global_step,
model=self.model,
optimizer=self.optimizer,
**checkpoint_items)
max_to_keep = self.config.trainer.export_max_to_keep or (
self.config.trainer.max_to_keep)
checkpoint_interval = self.config.trainer.export_checkpoint_interval or (
self.config.trainer.checkpoint_interval)
self._export_ckpt_manager = tf.train.CheckpointManager(
checkpoint,
directory=export_ckpt_dir,
checkpoint_name='ckpt',
step_counter=self.global_step,
max_to_keep=max_to_keep,
checkpoint_interval=checkpoint_interval,
)
# Make sure we export the last checkpoint.
last_checkpoint = (
self.global_step.numpy() == self._config.trainer.train_steps)
checkpoint_path = self._export_ckpt_manager.save(
checkpoint_number=self.global_step.numpy(),
check_interval=not last_checkpoint)
if checkpoint_path:
logging.info('Checkpoints exported: %s.', checkpoint_path)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the progressive trainer."""
# pylint: disable=g-direct-tensorflow-import
import os
from absl.testing import parameterized
import orbit
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import config_definitions as cfg
from official.modeling import optimization
from official.modeling.fast_training.progressive import policies
from official.modeling.fast_training.progressive import trainer as trainer_lib
from official.nlp.configs import bert
from official.utils.testing import mock_task
def all_strategy_combinations():
return combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],)
def get_exp_config():
return cfg.ExperimentConfig(
task=cfg.TaskConfig(
model=bert.PretrainerConfig()),
trainer=trainer_lib.ProgressiveTrainerConfig(
export_checkpoint=True,
export_checkpoint_interval=1,
export_only_final_stage_ckpt=False))
class TestPolicy(policies.ProgressivePolicy, mock_task.MockTask):
"""Just for testing purposes."""
def __init__(self, strategy, task_config, change_train_dataset=True):
self._strategy = strategy
self._change_train_dataset = change_train_dataset
self._my_train_dataset = None
mock_task.MockTask.__init__(self, params=task_config, logging_dir=None)
policies.ProgressivePolicy.__init__(self)
def num_stages(self) -> int:
return 2
def num_steps(self, stage_id: int) -> int:
return 2 if stage_id == 0 else 4
def get_model(self,
stage_id: int,
old_model: tf.keras.Model) -> tf.keras.Model:
del stage_id, old_model
return self.build_model()
def get_optimizer(self, stage_id: int) -> tf.keras.optimizers.Optimizer:
optimizer_type = 'sgd' if stage_id == 0 else 'adamw'
optimizer_config = cfg.OptimizationConfig({
'optimizer': {'type': optimizer_type},
'learning_rate': {'type': 'constant'}})
opt_factory = optimization.OptimizerFactory(optimizer_config)
return opt_factory.build_optimizer(opt_factory.build_learning_rate())
def get_train_dataset(self, stage_id: int) -> tf.data.Dataset:
if not self._change_train_dataset and self._my_train_dataset:
return self._my_train_dataset
if self._strategy:
self._my_train_dataset = orbit.utils.make_distributed_dataset(
self._strategy,
self._build_inputs,
stage_id)
else:
self._my_train_dataset = self._build_inputs(stage_id)
return self._my_train_dataset
def get_eval_dataset(self, stage_id: int) -> tf.data.Dataset:
if self._strategy:
return orbit.utils.make_distributed_dataset(
self._strategy,
self._build_inputs,
stage_id)
return self._build_inputs(stage_id)
def _build_inputs(self, stage_id):
def dummy_data(_):
batch_size = 2 if stage_id == 0 else 1
x = tf.zeros(shape=(batch_size, 2), dtype=tf.float32)
label = tf.zeros(shape=(batch_size, 1), dtype=tf.float32)
return x, label
dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat()
return dataset.map(
dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
class TrainerTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(TrainerTest, self).setUp()
self._config = get_exp_config()
def create_test_trainer(self, distribution, model_dir, change_train_dataset):
trainer = trainer_lib.ProgressiveTrainer(
self._config,
prog_task=TestPolicy(
distribution, self._config.task, change_train_dataset),
ckpt_dir=model_dir)
return trainer
@combinations.generate(all_strategy_combinations())
def test_checkpointing(self, distribution):
model_dir = self.get_temp_dir()
ckpt_file = os.path.join(model_dir, 'ckpt')
with distribution.scope():
trainer = self.create_test_trainer(distribution, model_dir, True)
self.assertFalse(trainer._task.is_last_stage)
trainer.train(tf.convert_to_tensor(4, dtype=tf.int32))
self.assertTrue(trainer._task.is_last_stage)
trainer.checkpoint.save(ckpt_file)
trainer = self.create_test_trainer(distribution, model_dir, True)
self.assertFalse(trainer._task.is_last_stage)
trainer.checkpoint.restore(ckpt_file + '-1')
self.assertTrue(trainer._task.is_last_stage)
@combinations.generate(all_strategy_combinations())
def test_train_dataset(self, distribution):
model_dir = self.get_temp_dir()
with distribution.scope():
trainer = self.create_test_trainer(distribution, model_dir, True)
# Using dataset of stage == 0
train_iter = tf.nest.map_structure(iter, trainer.train_dataset)
train_data = train_iter.next()[0]
if distribution.num_replicas_in_sync > 1:
train_data = train_data.values[0]
self.assertEqual(train_data.shape[0], 2)
trainer.train(tf.convert_to_tensor(4, dtype=tf.int32))
# Using dataset of stage == 1
train_iter = tf.nest.map_structure(iter, trainer.train_dataset)
train_data = train_iter.next()[0]
if distribution.num_replicas_in_sync > 1:
train_data = train_data.values[0]
self.assertEqual(train_data.shape[0], 1)
with self.assertRaises(SyntaxError):
trainer.train_dataset = None
@combinations.generate(all_strategy_combinations())
def test_train_dataset_no_switch(self, distribution):
model_dir = self.get_temp_dir()
with distribution.scope():
trainer = self.create_test_trainer(distribution, model_dir, False)
trainer.train(tf.convert_to_tensor(2, dtype=tf.int32))
# _train_iter is not reset since the dataset is not changed.
self.assertIsNotNone(trainer._train_iter)
with distribution.scope():
trainer = self.create_test_trainer(distribution, model_dir, True)
trainer.train(tf.convert_to_tensor(2, dtype=tf.int32))
# _train_iter is reset since the dataset changed.
self.assertIsNone(trainer._train_iter)
class TrainerWithMaskedLMTaskTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(TrainerWithMaskedLMTaskTest, self).setUp()
self._config = get_exp_config()
def create_test_trainer(self, distribution):
trainer = trainer_lib.ProgressiveTrainer(
self._config,
prog_task=TestPolicy(distribution, self._config.task),
ckpt_dir=self.get_temp_dir())
return trainer
@combinations.generate(all_strategy_combinations())
def test_trainer_train(self, distribution):
with distribution.scope():
trainer = self.create_test_trainer(distribution)
logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('training_loss', logs)
self.assertIn('learning_rate', logs)
@combinations.generate(all_strategy_combinations())
def test_trainer_validate(self, distribution):
with distribution.scope():
trainer = self.create_test_trainer(distribution)
logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('validation_loss', logs)
self.assertEqual(logs['counter'], 5. * distribution.num_replicas_in_sync)
@combinations.generate(
combinations.combine(
mixed_precision_dtype=['float32', 'bfloat16', 'float16'],
loss_scale=[None, 'dynamic', 128, 256],
))
def test_configure_optimizer(self, mixed_precision_dtype, loss_scale):
config = cfg.ExperimentConfig(
task=cfg.TaskConfig(
model=bert.PretrainerConfig()),
runtime=cfg.RuntimeConfig(
mixed_precision_dtype=mixed_precision_dtype, loss_scale=loss_scale),
trainer=trainer_lib.ProgressiveTrainerConfig(
export_checkpoint=True,
export_checkpoint_interval=1,
export_only_final_stage_ckpt=False))
task = TestPolicy(None, config.task)
trainer = trainer_lib.ProgressiveTrainer(config, task, self.get_temp_dir())
if mixed_precision_dtype != 'float16':
self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
elif mixed_precision_dtype == 'float16' and loss_scale is None:
self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('training_loss', metrics)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Util classes and functions."""
from absl import logging
import tensorflow as tf
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.training.tracking import tracking
class VolatileTrackable(tracking.AutoTrackable):
"""A util class to keep Trackables that might change instances."""
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
def reassign_trackable(self, **kwargs):
for k, v in kwargs.items():
delattr(self, k) # untrack this object
setattr(self, k, v) # track the new object
class CheckpointWithHooks(tf.train.Checkpoint):
"""Same as tf.train.Checkpoint but supports hooks.
In progressive training, use this class instead of tf.train.Checkpoint.
Since the network architecture changes during progressive training, we need to
prepare something (like switch to the correct architecture) before loading the
checkpoint. This class supports a hook that will be executed before checkpoint
loading.
"""
def __init__(self, before_load_hook, **kwargs):
self._before_load_hook = before_load_hook
super(CheckpointWithHooks, self).__init__(**kwargs)
# override
def read(self, save_path, options=None):
self._before_load_hook(save_path)
logging.info('Ran before_load_hook.')
super(CheckpointWithHooks, self).read(save_path=save_path, options=options)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for grad_utils."""
import tensorflow as tf
from official.modeling import grad_utils
from official.modeling import performance
class GradUtilsTest(tf.test.TestCase):
def test_minimize(self):
optimizer = tf.keras.optimizers.SGD(0.1)
with tf.GradientTape() as tape:
model = tf.keras.layers.Dense(2)
outputs = model(tf.zeros((2, 2), tf.float32))
loss = tf.reduce_mean(outputs)
grad_utils.minimize_using_explicit_allreduce(tape, optimizer, loss,
model.trainable_variables)
def test_minimize_fp16(self):
optimizer = performance.configure_optimizer(
tf.keras.optimizers.SGD(0.1), use_float16=True)
performance.set_mixed_precision_policy(tf.float16)
with tf.GradientTape() as tape:
model = tf.keras.layers.Dense(2)
outputs = model(tf.zeros((2, 2), tf.float16))
loss = tf.reduce_mean(outputs)
grad_utils.minimize_using_explicit_allreduce(tape, optimizer, loss,
model.trainable_variables)
# Test other fp16 settings.
def _clip_by_global_norm(grads_and_vars):
grads, tvars = list(zip(*grads_and_vars))
(grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
return zip(grads, tvars)
with tf.GradientTape() as tape:
model = tf.keras.layers.Dense(2)
outputs = model(tf.zeros((2, 2), tf.float16))
loss = tf.reduce_mean(outputs)
optimizer = performance.configure_optimizer(
tf.keras.optimizers.SGD(0.1), use_float16=True, loss_scale=128)
grad_utils.minimize_using_explicit_allreduce(
tape,
optimizer,
loss,
model.trainable_variables,
pre_allreduce_callbacks=[_clip_by_global_norm],
post_allreduce_callbacks=[_clip_by_global_norm])
def test_set_mixed_precision_policy(self):
performance.set_mixed_precision_policy(tf.float16)
performance.set_mixed_precision_policy(tf.bfloat16)
performance.set_mixed_precision_policy(tf.float32)
with self.assertRaises(ValueError):
performance.set_mixed_precision_policy(tf.int32)
if __name__ == '__main__':
tf.test.main()
......@@ -14,14 +14,19 @@
"""Functions and classes related to training performance."""
from absl import logging
import tensorflow as tf
def configure_optimizer(optimizer,
use_float16=False,
use_graph_rewrite=False,
loss_scale=None):
loss_scale=None,
use_graph_rewrite=None):
"""Configures optimizer object with performance options."""
if use_graph_rewrite is not None:
logging.warning('`use_graph_rewrite` is deprecated inside '
'`configure_optimizer`. Please remove the usage.')
del use_graph_rewrite
if use_float16:
if loss_scale in (None, 'dynamic'):
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
......@@ -29,13 +34,6 @@ def configure_optimizer(optimizer,
# loss_scale is a number. We interpret that as a fixed loss scale.
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
optimizer, dynamic=False, initial_scale=loss_scale)
if use_graph_rewrite:
# Note: the model dtype must be 'float32', which will ensure
# tf.keras.mixed_precision and enable_mixed_precision_graph_rewrite do not
# double up.
optimizer = (
tf.compat.v1.mixed_precision.enable_mixed_precision_graph_rewrite(
optimizer))
return optimizer
......
......@@ -110,6 +110,8 @@ def get_activation(identifier, use_keras_layer=False):
"swish": "swish",
"sigmoid": "sigmoid",
"relu6": tf.nn.relu6,
"hard_swish": activations.hard_swish,
"hard_sigmoid": activations.hard_sigmoid,
}
if identifier in keras_layer_allowlist:
return tf.keras.layers.Activation(keras_layer_allowlist[identifier])
......
# TF-NLP Model Garden
## Introduction
This TF-NLP library provides a collection of scripts for the training and
evaluation of transformer-based models, on various tasks such as sentence
classification, question answering, and translation. Additionally, we provide
checkpoints of pretrained models which can be finetuned on downstream tasks.
### How to Train Models
Model Garden can be easily installed using PIP
(`pip install tf-models-nightly`). After installation, check out
[this instruction](https://github.com/tensorflow/models/blob/master/official/nlp/docs/train.md)
on how to train models with this codebase.
## Available Tasks
There are two available model configs (we will add more) under
`configs/experiments/`:
| Dataset | Task | Config | Example command |
| ----------------- | ------------------------ | ------- | ---- |
| GLUE/MNLI-matched | bert/sentence_prediction | [glue_mnli_matched.yaml](https://github.com/tensorflow/models/blob/master/official/nlp/configs/experiments/glue_mnli_matched.yaml) | <details> <summary>finetune BERT-base on this task</summary> PARAMS=runtime.distribution_strategy=mirrored<br/>PARAMS=${PARAMS},task.train_data.input_path=/path-to-your-training-data/<br/>PARAMS=${PARAMS},task.hub_module_url=https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4<br/><br/>python3 train.py \\<br/> --experiment=bert/sentence_prediction \\<br/> --mode=train \\<br/> --model_dir=/a-folder-to-hold-checkpoints-and-logs/ \\<br/> --config_file=configs/models/bert_en_uncased_base.yaml \\<br/> --config_file=configs/experiments/glue_mnli_matched.yaml \\<br/> --params_override=${PARAMS}</details> |
| SQuAD v1.1 | bert/squad | [squad_v1.yaml](https://github.com/tensorflow/models/blob/master/official/nlp/configs/experiments/squad_v1.yaml) | <details> <summary>finetune BERT-base on this task</summary> PARAMS=runtime.distribution_strategy=mirrored<br/>PARAMS=${PARAMS},task.train_data.input_path=/path-to-your-training-data/<br/>PARAMS=${PARAMS},task.hub_module_url=https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4<br/><br/>python3 train.py \\<br/> --experiment=bert/squad \\<br/> --mode=train \\<br/> --model_dir=/a-folder-to-hold-checkpoints-and-logs/ \\<br/> --config_file=configs/models/bert_en_uncased_base.yaml \\<br/> --config_file=configs/experiments/squad_v1.yaml \\<br/> --params_override=${PARAMS}</details> |
One example on how to use the config file: if you want to work on the SQuAD
question answering task, set
`--config_file=configs/experiments/squad_v1.yaml` and
`--experiment=bert/squad`
as arguments to `train.py`.
## Available Model Configs
There are two available model configs (we will add more) under
`configs/models/`:
| Model | Config | Pretrained checkpoint & Vocabulary | TF-HUB SavedModel | Example command |
| ------------ | ------- | ---------------------------------- | ----------------- | --------------- |
| BERT-base | [bert_en_uncased_base.yaml](https://github.com/tensorflow/models/blob/master/official/nlp/configs/models/bert_en_uncased_base.yaml) | [uncased_L-12_H-768_A-12](https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/uncased_L-12_H-768_A-12.tar.gz) | [uncased_L-12_H-768_A-12](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/) | <details> <summary>finetune on SQuAD v1.1</summary> PARAMS=runtime.distribution_strategy=mirrored<br/>PARAMS=${PARAMS},task.train_data.input_path=/path-to-your-training-data/<br/>PARAMS=${PARAMS},task.hub_module_url=https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4<br/><br/>python3 train.py \\<br/> --experiment=bert/squad \\<br/> --mode=train \\<br/> --model_dir=/a-folder-to-hold-checkpoints-and-logs/ \\<br/> --config_file=configs/models/bert_en_uncased_base.yaml \\<br/> --config_file=configs/experiments/squad_v1.yaml \\<br/> --params_override=${PARAMS}</details> |
| ALBERT-base | [albert_base.yaml](https://github.com/tensorflow/models/blob/master/official/nlp/configs/models/albert_base.yaml) | [albert_en_base](https://storage.googleapis.com/tf_model_garden/nlp/albert/albert_base.tar.gz) | [albert_en_base](https://tfhub.dev/tensorflow/albert_en_base/3) | <details> <summary>finetune on SQuAD v1.1</summary> PARAMS=runtime.distribution_strategy=mirrored<br/>PARAMS=${PARAMS},task.train_data.input_path=/path-to-your-training-data/<br/>PARAMS=${PARAMS},task.hub_module_url=https://tfhub.dev/tensorflow/albert_en_base/3<br/><br/>python3 train.py \\<br/> --experiment=bert/squad \\<br/> --mode=train \\<br/> --model_dir=/a-folder-to-hold-checkpoints-and-logs/ \\<br/> --config_file=configs/models/albert_base.yaml \\<br/> --config_file=configs/experiments/squad_v1.yaml \\<br/> --params_override=${PARAMS}</details> |
One example on how to use the config file: if you want to train an ALBERT-base
model, set `--config_file=configs/models/albert_base.yaml` as an argument to
`train.py`.
## Useful links
[How to Train Models](https://github.com/tensorflow/models/blob/master/official/nlp/docs/train.md)
[List of Pretrained Models](https://github.com/tensorflow/models/blob/master/official/nlp/docs/pretrained_models.md)
[How to Publish Models](https://github.com/tensorflow/models/blob/master/official/nlp/docs/tfhub.md)
[TensorFlow blog on Model Garden](https://blog.tensorflow.org/2020/03/introducing-model-garden-for-tensorflow-2.html).
# TensorFlow NLP Modelling Toolkit
This codebase provides a Natrual Language Processing modeling toolkit written in
[TF2](https://www.tensorflow.org/guide/effective_tf2). It allows researchers and
developers to reproduce state-of-the-art model results and train custom models
to experiment new research ideas.
## Features
* Reusable and modularized modeling building blocks
* State-of-the-art reproducible
* Easy to customize and extend
* End-to-end training
* Distributed trainable on both GPUs and TPUs
## Major components
### Libraries
We provide modeling library to allow users to train custom models for new
research ideas. Detailed intructions can be found in READMEs in each folder.
* [modeling/](modeling): modeling library that provides building blocks
(e.g.,Layers, Networks, and Models) that can be assembled into
transformer-based achitectures .
* [data/](data): binaries and utils for input preprocessing, tokenization,
etc.
### State-of-the-Art models and examples
We provide SoTA model implementations, pre-trained models, training and
evaluation examples, and command lines. Detail instructions can be found in the
READMEs for specific papers.
1. [BERT](MODEL_GARDEN.md#available-model-configs): [BERT: Pre-training of Deep Bidirectional Transformers for
Language Understanding](https://arxiv.org/abs/1810.04805) by Devlin et al.,
2018
2. [ALBERT](MODEL_GARDEN.md#available-model-configs):
[A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942)
by Lan et al., 2019
3. [XLNet](xlnet):
[XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237)
by Yang et al., 2019
4. [Transformer for translation](transformer):
[Attention Is All You Need](https://arxiv.org/abs/1706.03762) by Vaswani et
al., 2017
### Common Training Driver
We provide a single common driver [train.py](train.py) to train above SoTA
models on popluar tasks. Please see [docs/train.md](docs/train.md) for
more details.
### Pre-trained models with checkpoints and TF-Hub
We provide a large collection of baselines and checkpoints for NLP pre-trained
models. Please see [docs/pretrained_models.md](docs/pretrained_models.md) for
more details.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# BERT (Bidirectional Encoder Representations from Transformers)
**WARNING**: We are on the way to deprecate most of the code in this directory.
Please see
[this link](https://github.com/tensorflow/models/blob/master/official/nlp/docs/train.md)
for the new tutorial and use the new code in `nlp/modeling`. This README is
still correct for this legacy implementation.
The academic paper which describes BERT in detail and provides full results on a
number of tasks can be found here: https://arxiv.org/abs/1810.04805.
This repository contains TensorFlow 2.x implementation for BERT.
## Contents
* [Contents](#contents)
* [Pre-trained Models](#pre-trained-models)
* [Restoring from Checkpoints](#restoring-from-checkpoints)
* [Set Up](#set-up)
* [Process Datasets](#process-datasets)
* [Fine-tuning with BERT](#fine-tuning-with-bert)
* [Cloud GPUs and TPUs](#cloud-gpus-and-tpus)
* [Sentence and Sentence-pair Classification Tasks](#sentence-and-sentence-pair-classification-tasks)
* [SQuAD 1.1](#squad-1.1)
## Pre-trained Models
We released both checkpoints and tf.hub modules as the pretrained models for
fine-tuning. They are TF 2.x compatible and are converted from the checkpoints
released in TF 1.x official BERT repository
[google-research/bert](https://github.com/google-research/bert)
in order to keep consistent with BERT paper.
### Access to Pretrained Checkpoints
Pretrained checkpoints can be found in the following links:
**Note: We have switched BERT implementation
to use Keras functional-style networks in [nlp/modeling](../modeling).
The new checkpoints are:**
* **[`BERT-Large, Uncased (Whole Word Masking)`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/wwm_uncased_L-24_H-1024_A-16.tar.gz)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Large, Cased (Whole Word Masking)`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/wwm_cased_L-24_H-1024_A-16.tar.gz)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Base, Uncased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12.tar.gz)**:
12-layer, 768-hidden, 12-heads, 110M parameters
* **[`BERT-Large, Uncased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16.tar.gz)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Base, Cased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/cased_L-12_H-768_A-12.tar.gz)**:
12-layer, 768-hidden, 12-heads , 110M parameters
* **[`BERT-Large, Cased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/cased_L-24_H-1024_A-16.tar.gz)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Base, Multilingual Cased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/multi_cased_L-12_H-768_A-12.tar.gz)**:
104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
We recommend to host checkpoints on Google Cloud storage buckets when you use
Cloud GPU/TPU.
### Restoring from Checkpoints
`tf.train.Checkpoint` is used to manage model checkpoints in TF 2. To restore
weights from provided pre-trained checkpoints, you can use the following code:
```python
init_checkpoint='the pretrained model checkpoint path.'
model=tf.keras.Model() # Bert pre-trained model as feature extractor.
checkpoint = tf.train.Checkpoint(model=model)
checkpoint.restore(init_checkpoint)
```
Checkpoints featuring native serialized Keras models
(i.e. model.load()/load_weights()) will be available soon.
### Access to Pretrained hub modules.
Pretrained tf.hub modules in TF 2.x SavedModel format can be found in the
following links:
* **[`BERT-Large, Uncased (Whole Word Masking)`](https://tfhub.dev/tensorflow/bert_en_wwm_uncased_L-24_H-1024_A-16/)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Large, Cased (Whole Word Masking)`](https://tfhub.dev/tensorflow/bert_en_wwm_cased_L-24_H-1024_A-16/)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Base, Uncased`](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/)**:
12-layer, 768-hidden, 12-heads, 110M parameters
* **[`BERT-Large, Uncased`](https://tfhub.dev/tensorflow/bert_en_uncased_L-24_H-1024_A-16/)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Base, Cased`](https://tfhub.dev/tensorflow/bert_en_cased_L-12_H-768_A-12/)**:
12-layer, 768-hidden, 12-heads , 110M parameters
* **[`BERT-Large, Cased`](https://tfhub.dev/tensorflow/bert_en_cased_L-24_H-1024_A-16/)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Base, Multilingual Cased`](https://tfhub.dev/tensorflow/bert_multi_cased_L-12_H-768_A-12/)**:
104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
* **[`BERT-Base, Chinese`](https://tfhub.dev/tensorflow/bert_zh_L-12_H-768_A-12/)**:
Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads,
110M parameters
## Set Up
```shell
export PYTHONPATH="$PYTHONPATH:/path/to/models"
```
Install `tf-nightly` to get latest updates:
```shell
pip install tf-nightly-gpu
```
With TPU, GPU support is not necessary. First, you need to create a `tf-nightly`
TPU with [ctpu tool](https://github.com/tensorflow/tpu/tree/master/tools/ctpu):
```shell
ctpu up -name <instance name> --tf-version=”nightly”
```
Second, you need to install TF 2 `tf-nightly` on your VM:
```shell
pip install tf-nightly
```
## Process Datasets
### Pre-training
There is no change to generate pre-training data. Please use the script
[`../data/create_pretraining_data.py`](../data/create_pretraining_data.py)
which is essentially branched from [BERT research repo](https://github.com/google-research/bert)
to get processed pre-training data and it adapts to TF2 symbols and python3
compatibility.
Running the pre-training script requires an input and output directory, as well as a vocab file. Note that max_seq_length will need to match the sequence length parameter you specify when you run pre-training.
Example shell script to call create_pretraining_data.py
```
export WORKING_DIR='local disk or cloud location'
export BERT_DIR='local disk or cloud location'
python models/official/nlp/data/create_pretraining_data.py \
--input_file=$WORKING_DIR/input/input.txt \
--output_file=$WORKING_DIR/output/tf_examples.tfrecord \
--vocab_file=$BERT_DIR/wwm_uncased_L-24_H-1024_A-16/vocab.txt \
--do_lower_case=True \
--max_seq_length=512 \
--max_predictions_per_seq=76 \
--masked_lm_prob=0.15 \
--random_seed=12345 \
--dupe_factor=5
```
### Fine-tuning
To prepare the fine-tuning data for final model training, use the
[`../data/create_finetuning_data.py`](../data/create_finetuning_data.py) script.
Resulting datasets in `tf_record` format and training meta data should be later
passed to training or evaluation scripts. The task-specific arguments are
described in following sections:
* GLUE
Users can download the
[GLUE data](https://gluebenchmark.com/tasks) by running
[this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
and unpack it to some directory `$GLUE_DIR`.
Also, users can download [Pretrained Checkpoint](#access-to-pretrained-checkpoints) and locate on some directory `$BERT_DIR` instead of using checkpoints on Google Cloud Storage.
```shell
export GLUE_DIR=~/glue
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
export TASK_NAME=MNLI
export OUTPUT_DIR=gs://some_bucket/datasets
python ../data/create_finetuning_data.py \
--input_data_dir=${GLUE_DIR}/${TASK_NAME}/ \
--vocab_file=${BERT_DIR}/vocab.txt \
--train_data_output_path=${OUTPUT_DIR}/${TASK_NAME}_train.tf_record \
--eval_data_output_path=${OUTPUT_DIR}/${TASK_NAME}_eval.tf_record \
--meta_data_file_path=${OUTPUT_DIR}/${TASK_NAME}_meta_data \
--fine_tuning_task_type=classification --max_seq_length=128 \
--classification_task_name=${TASK_NAME}
```
* SQUAD
The [SQuAD website](https://rajpurkar.github.io/SQuAD-explorer/) contains
detailed information about the SQuAD datasets and evaluation.
The necessary files can be found here:
* [train-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json)
* [dev-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json)
* [evaluate-v1.1.py](https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py)
* [train-v2.0.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json)
* [dev-v2.0.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json)
* [evaluate-v2.0.py](https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/)
```shell
export SQUAD_DIR=~/squad
export SQUAD_VERSION=v1.1
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
export OUTPUT_DIR=gs://some_bucket/datasets
python ../data/create_finetuning_data.py \
--squad_data_file=${SQUAD_DIR}/train-${SQUAD_VERSION}.json \
--vocab_file=${BERT_DIR}/vocab.txt \
--train_data_output_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
--meta_data_file_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_meta_data \
--fine_tuning_task_type=squad --max_seq_length=384
```
Note: To create fine-tuning data with SQUAD 2.0, you need to add flag `--version_2_with_negative=True`.
## Fine-tuning with BERT
### Cloud GPUs and TPUs
* Cloud Storage
The unzipped pre-trained model files can also be found in the Google Cloud
Storage folder `gs://cloud-tpu-checkpoints/bert/keras_bert`. For example:
```shell
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
export MODEL_DIR=gs://some_bucket/my_output_dir
```
Currently, users are able to access to `tf-nightly` TPUs and the following TPU
script should run with `tf-nightly`.
* GPU -> TPU
Just add the following flags to `run_classifier.py` or `run_squad.py`:
```shell
--distribution_strategy=tpu
--tpu=grpc://${TPU_IP_ADDRESS}:8470
```
### Sentence and Sentence-pair Classification Tasks
This example code fine-tunes `BERT-Large` on the Microsoft Research Paraphrase
Corpus (MRPC) corpus, which only contains 3,600 examples and can fine-tune in a
few minutes on most GPUs.
We use the `BERT-Large` (uncased_L-24_H-1024_A-16) as an example throughout the
workflow.
For GPU memory of 16GB or smaller, you may try to use `BERT-Base`
(uncased_L-12_H-768_A-12).
```shell
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
export MODEL_DIR=gs://some_bucket/my_output_dir
export GLUE_DIR=gs://some_bucket/datasets
export TASK=MRPC
python run_classifier.py \
--mode='train_and_eval' \
--input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
--train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
--eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
--bert_config_file=${BERT_DIR}/bert_config.json \
--init_checkpoint=${BERT_DIR}/bert_model.ckpt \
--train_batch_size=4 \
--eval_batch_size=4 \
--steps_per_loop=1 \
--learning_rate=2e-5 \
--num_train_epochs=3 \
--model_dir=${MODEL_DIR} \
--distribution_strategy=mirrored
```
Alternatively, instead of specifying `init_checkpoint`, you can specify
`hub_module_url` to employ a pretraind BERT hub module, e.g.,
` --hub_module_url=https://tfhub.dev/tensorflow/bert_en_uncased_L-24_H-1024_A-16/1`.
After training a model, to get predictions from the classifier, you can set the
`--mode=predict` and offer the test set tfrecords to `--eval_data_path`.
Output will be created in file called test_results.tsv in the output folder.
Each line will contain output for each sample, columns are the class
probabilities.
```shell
python run_classifier.py \
--mode='predict' \
--input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
--eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
--bert_config_file=${BERT_DIR}/bert_config.json \
--eval_batch_size=4 \
--model_dir=${MODEL_DIR} \
--distribution_strategy=mirrored
```
To use TPU, you only need to switch distribution strategy type to `tpu` with TPU
information and use remote storage for model checkpoints.
```shell
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
export TPU_IP_ADDRESS='???'
export MODEL_DIR=gs://some_bucket/my_output_dir
export GLUE_DIR=gs://some_bucket/datasets
export TASK=MRPC
python run_classifier.py \
--mode='train_and_eval' \
--input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
--train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
--eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
--bert_config_file=${BERT_DIR}/bert_config.json \
--init_checkpoint=${BERT_DIR}/bert_model.ckpt \
--train_batch_size=32 \
--eval_batch_size=32 \
--steps_per_loop=1000 \
--learning_rate=2e-5 \
--num_train_epochs=3 \
--model_dir=${MODEL_DIR} \
--distribution_strategy=tpu \
--tpu=grpc://${TPU_IP_ADDRESS}:8470
```
Note that, we specify `steps_per_loop=1000` for TPU, because running a loop of
training steps inside a `tf.function` can significantly increase TPU utilization
and callbacks will not be called inside the loop.
### SQuAD 1.1
The Stanford Question Answering Dataset (SQuAD) is a popular question answering
benchmark dataset. See more in [SQuAD website](https://rajpurkar.github.io/SQuAD-explorer/).
We use the `BERT-Large` (uncased_L-24_H-1024_A-16) as an example throughout the
workflow.
For GPU memory of 16GB or smaller, you may try to use `BERT-Base`
(uncased_L-12_H-768_A-12).
```shell
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
export SQUAD_DIR=gs://some_bucket/datasets
export MODEL_DIR=gs://some_bucket/my_output_dir
export SQUAD_VERSION=v1.1
python run_squad.py \
--input_meta_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_meta_data \
--train_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
--predict_file=${SQUAD_DIR}/dev-v1.1.json \
--vocab_file=${BERT_DIR}/vocab.txt \
--bert_config_file=${BERT_DIR}/bert_config.json \
--init_checkpoint=${BERT_DIR}/bert_model.ckpt \
--train_batch_size=4 \
--predict_batch_size=4 \
--learning_rate=8e-5 \
--num_train_epochs=2 \
--model_dir=${MODEL_DIR} \
--distribution_strategy=mirrored
```
Similarily, you can replace `init_checkpoint` FLAG with `hub_module_url` to
specify a hub module path.
`run_squad.py` writes the prediction for `--predict_file` by default. If you set
the `--model=predict` and offer the SQuAD test data, the scripts will generate
the prediction json file.
To use TPU, you need switch distribution strategy type to `tpu` with TPU
information.
```shell
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
export TPU_IP_ADDRESS='???'
export MODEL_DIR=gs://some_bucket/my_output_dir
export SQUAD_DIR=gs://some_bucket/datasets
export SQUAD_VERSION=v1.1
python run_squad.py \
--input_meta_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_meta_data \
--train_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
--predict_file=${SQUAD_DIR}/dev-v1.1.json \
--vocab_file=${BERT_DIR}/vocab.txt \
--bert_config_file=${BERT_DIR}/bert_config.json \
--init_checkpoint=${BERT_DIR}/bert_model.ckpt \
--train_batch_size=32 \
--learning_rate=8e-5 \
--num_train_epochs=2 \
--model_dir=${MODEL_DIR} \
--distribution_strategy=tpu \
--tpu=grpc://${TPU_IP_ADDRESS}:8470
```
The dev set predictions will be saved into a file called predictions.json in the
model_dir:
```shell
python $SQUAD_DIR/evaluate-v1.1.py $SQUAD_DIR/dev-v1.1.json ./squad/predictions.json
```
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment