Commit 88253ce5 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 326286926
parent 52371ffe
@@ -71,10 +71,8 @@ class LearningRateSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
def create_optimizer(params: params_dict.ParamsDict):
"""Creates optimizer."""
lr_schedule = LearningRateSchedule(
params.learning_rate,
params.hidden_size,
params.learning_rate_warmup_steps)
lr_schedule = LearningRateSchedule(params.learning_rate, params.hidden_size,
params.learning_rate_warmup_steps)
return tf.keras.optimizers.Adam(
learning_rate=lr_schedule,
beta_1=params.adam_beta1,
......
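For context on the hunk above: the LearningRateSchedule body is not shown in
this diff, but a schedule with this constructor signature typically implements
the Transformer recipe of linear warmup followed by reciprocal-square-root
decay. A minimal sketch, assuming that behavior (the class name and constants
below are illustrative):

import tensorflow as tf

class WarmupRsqrtSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Linear warmup, then rsqrt decay, as in "Attention Is All You Need"."""

  def __init__(self, initial_learning_rate, hidden_size, warmup_steps):
    super(WarmupRsqrtSchedule, self).__init__()
    self.initial_learning_rate = initial_learning_rate
    self.hidden_size = hidden_size
    self.warmup_steps = float(warmup_steps)

  def __call__(self, global_step):
    step = tf.cast(global_step, tf.float32)
    lr = self.initial_learning_rate * self.hidden_size**-0.5
    lr *= tf.minimum(1.0, step / self.warmup_steps)           # warmup
    lr *= tf.math.rsqrt(tf.maximum(step, self.warmup_steps))  # decay
    return lr

optimizer = tf.keras.optimizers.Adam(
    learning_rate=WarmupRsqrtSchedule(2.0, 512, 16000), beta_1=0.9)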
@@ -16,6 +16,7 @@
"""Processes crawled content from news URLs by generating tfrecords."""
import os
from absl import app
from absl import flags
from official.nlp.nhnet import raw_data_processor
......
@@ -20,6 +20,7 @@ import json
import multiprocessing
import os
import urllib.parse
import tensorflow as tf
from official.nlp.bert import tokenization
@@ -47,10 +48,10 @@ class RawDataProcessor(object):
max_num_articles: Maximum number of articles in a story.
include_article_title_in_passage: Whether to include article title in
article passage.
include_text_snippet_in_example: Whether to include text snippet
(headline and article content) in generated tensorflow Examples, for
debug usage. If include_article_title_in_passage=True, title and body
will be separated by [SEP].
include_text_snippet_in_example: Whether to include text snippet (headline
and article content) in generated tensorflow Examples, for debug usage.
If include_article_title_in_passage=True, title and body will be
separated by [SEP].
"""
self.articles = dict()
self.tokenizer = tokenization.FullTokenizer(
@@ -156,6 +157,7 @@ class RawDataProcessor(object):
def _get_single_story_features(self, story_headline, articles):
"""Converts a list of articles to a tensorflow Example."""
def get_text_snippet(article):
if article.text_b:
return " [SEP] ".join([article.text_a, article.text_b])
......
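The hunks above convert crawled articles into tensorflow Examples, joining
title and body with [SEP] when both are present. A hypothetical sketch of that
packing; the feature key below is made up, since the real keys are not visible
in this diff:

import tensorflow as tf

def get_text_snippet(text_a, text_b):
  # Mirrors the helper above: title and body separated by [SEP].
  return " [SEP] ".join([text_a, text_b]) if text_b else text_a

def snippets_to_example(snippets):
  feature = {
      "text_snippets": tf.train.Feature(  # hypothetical feature key
          bytes_list=tf.train.BytesList(
              value=[s.encode("utf-8") for s in snippets])),
  }
  return tf.train.Example(features=tf.train.Features(feature=feature))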
@@ -21,6 +21,7 @@ from __future__ import print_function
import os
# Import libraries
from absl import app
from absl import flags
from absl import logging
......
@@ -44,6 +44,8 @@ def encoder_common_layers(transformer_block):
transformer_block._intermediate_dense, transformer_block._output_dense,
transformer_block._output_layer_norm
]
# pylint: enable=protected-access
......
@@ -14,6 +14,7 @@
# limitations under the License.
# ==============================================================================
"""ELECTRA pretraining task (Joint Masked LM and Replaced Token Detection)."""
import dataclasses
import tensorflow as tf
......
@@ -14,6 +14,7 @@
# limitations under the License.
# ==============================================================================
"""Masked language task."""
import dataclasses
import tensorflow as tf
......
@@ -52,8 +52,7 @@ class MLMTaskTest(tf.test.TestCase):
task.validation_step(next(iterator), model, metrics=metrics)
# Saves a checkpoint.
ckpt = tf.train.Checkpoint(
model=model, **model.checkpoint_items)
ckpt = tf.train.Checkpoint(model=model, **model.checkpoint_items)
ckpt.save(config.init_checkpoint)
task.initialize(model)
......
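The test above saves the model together with `model.checkpoint_items`. A
minimal sketch of that round trip, assuming `checkpoint_items` is a dict of
name-to-trackable as the call implies:

import tensorflow as tf

def save_and_restore(model, checkpoint_prefix):
  # Export the model plus its extra trackables (e.g. the encoder) so that
  # downstream tasks can restore the pieces they need.
  ckpt = tf.train.Checkpoint(model=model, **model.checkpoint_items)
  path = ckpt.save(checkpoint_prefix)
  restorer = tf.train.Checkpoint(model=model, **model.checkpoint_items)
  restorer.restore(path).assert_existing_objects_matched()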
@@ -111,9 +111,7 @@ class QuestionAnsweringTask(base_task.Task):
tf.cast(start_logits, dtype=tf.float32),
from_logits=True)
end_loss = tf.keras.losses.sparse_categorical_crossentropy(
end_positions,
tf.cast(end_logits, dtype=tf.float32),
from_logits=True)
end_positions, tf.cast(end_logits, dtype=tf.float32), from_logits=True)
loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
return loss
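The span loss reformatted above averages two independent sparse cross-entropies
over the start and end positions. A self-contained sketch with toy shapes:

import tensorflow as tf

start_logits = tf.random.normal([8, 384])  # [batch_size, seq_length]
end_logits = tf.random.normal([8, 384])
start_positions = tf.random.uniform([8], maxval=384, dtype=tf.int32)
end_positions = tf.random.uniform([8], maxval=384, dtype=tf.int32)

start_loss = tf.keras.losses.sparse_categorical_crossentropy(
    start_positions, tf.cast(start_logits, dtype=tf.float32), from_logits=True)
end_loss = tf.keras.losses.sparse_categorical_crossentropy(
    end_positions, tf.cast(end_logits, dtype=tf.float32), from_logits=True)
loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2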
@@ -142,8 +140,7 @@ class QuestionAnsweringTask(base_task.Task):
kwargs = dict(
examples=eval_examples,
tokenizer=tokenization.FullTokenizer(
vocab_file=params.vocab_file,
do_lower_case=params.do_lower_case),
vocab_file=params.vocab_file, do_lower_case=params.do_lower_case),
max_seq_length=params.seq_length,
doc_stride=params.doc_stride,
max_query_length=params.query_length,
@@ -192,8 +189,8 @@ class QuestionAnsweringTask(base_task.Task):
input_path = self._tf_record_input_path
dataloader_params = params.replace(input_path=input_path)
return data_loader_factory.get_data_loader(
dataloader_params).load(input_context)
return data_loader_factory.get_data_loader(dataloader_params).load(
input_context)
def build_metrics(self, training=None):
del training
@@ -209,16 +206,19 @@ class QuestionAnsweringTask(base_task.Task):
def process_metrics(self, metrics, labels, model_outputs):
metrics = dict([(metric.name, metric) for metric in metrics])
start_logits, end_logits = model_outputs
metrics['start_position_accuracy'].update_state(
labels['start_positions'], start_logits)
metrics['end_position_accuracy'].update_state(
labels['end_positions'], end_logits)
metrics['start_position_accuracy'].update_state(labels['start_positions'],
start_logits)
metrics['end_position_accuracy'].update_state(labels['end_positions'],
end_logits)
def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
start_logits, end_logits = model_outputs
compiled_metrics.update_state(
y_true=labels, # labels has keys 'start_positions' and 'end_positions'.
y_pred={'start_positions': start_logits, 'end_positions': end_logits})
y_pred={
'start_positions': start_logits,
'end_positions': end_logits
})
def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
features, _ = inputs
@@ -242,16 +242,16 @@ class QuestionAnsweringTask(base_task.Task):
state = []
for unique_ids, start_logits, end_logits in zip(
step_outputs['unique_ids'],
step_outputs['start_logits'],
step_outputs['unique_ids'], step_outputs['start_logits'],
step_outputs['end_logits']):
u_ids, s_logits, e_logits = (
unique_ids.numpy(), start_logits.numpy(), end_logits.numpy())
u_ids, s_logits, e_logits = (unique_ids.numpy(), start_logits.numpy(),
end_logits.numpy())
for values in zip(u_ids, s_logits, e_logits):
state.append(self.raw_aggregated_result(
unique_id=values[0],
start_logits=values[1].tolist(),
end_logits=values[2].tolist()))
state.append(
self.raw_aggregated_result(
unique_id=values[0],
start_logits=values[1].tolist(),
end_logits=values[2].tolist()))
return state
def reduce_aggregated_logs(self, aggregated_logs):
@@ -269,13 +269,13 @@ class QuestionAnsweringTask(base_task.Task):
self.task_config.null_score_diff_threshold),
verbose=False))
with tf.io.gfile.GFile(
self.task_config.validation_data.input_path, 'r') as reader:
with tf.io.gfile.GFile(self.task_config.validation_data.input_path,
'r') as reader:
dataset_json = json.load(reader)
pred_dataset = dataset_json['data']
if self.task_config.validation_data.version_2_with_negative:
eval_metrics = squad_evaluate_v2_0.evaluate(
pred_dataset, all_predictions, scores_diff)
eval_metrics = squad_evaluate_v2_0.evaluate(pred_dataset, all_predictions,
scores_diff)
# Filter out useless metrics, such as start_position_accuracy that
# we did not actually compute.
eval_metrics = {
@@ -284,13 +284,16 @@ class QuestionAnsweringTask(base_task.Task):
'final_f1': eval_metrics['final_f1'] / 100.0, # scale back to [0, 1].
'f1_threshold': eval_metrics['final_f1_thresh'],
'has_answer_exact_match': eval_metrics['HasAns_exact'],
'has_answer_f1': eval_metrics['HasAns_f1']}
'has_answer_f1': eval_metrics['HasAns_f1']
}
else:
eval_metrics = squad_evaluate_v1_1.evaluate(pred_dataset, all_predictions)
# Filter out useless metrics, such as start_position_accuracy that
# we did not actually compute.
eval_metrics = {'exact_match': eval_metrics['exact_match'],
'final_f1': eval_metrics['final_f1']}
eval_metrics = {
'exact_match': eval_metrics['exact_match'],
'final_f1': eval_metrics['final_f1']
}
return eval_metrics
......
@@ -17,6 +17,7 @@
import itertools
import json
import os
from absl.testing import parameterized
import tensorflow as tf
......
@@ -35,7 +35,6 @@ from official.nlp.data import data_loader_factory
from official.nlp.modeling import models
from official.nlp.tasks import utils
METRIC_TYPES = frozenset(
['accuracy', 'matthews_corrcoef', 'pearson_spearman_corr'])
@@ -137,7 +136,8 @@ class SentencePredictionTask(base_task.Task):
metrics = [tf.keras.metrics.MeanSquaredError()]
else:
metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')]
tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')
]
return metrics
def process_metrics(self, metrics, labels, model_outputs):
......
@@ -250,8 +250,7 @@ def predict(task: TaggingTask, params: cfg.DataConfig,
cur_predict_ids = state['predict_ids']
cur_sentence_ids = state['sentence_ids']
for batch_predict_ids, batch_label_mask, batch_sentence_ids in zip(
outputs['predict_ids'], outputs['label_mask'],
outputs['sentence_ids']):
outputs['predict_ids'], outputs['label_mask'], outputs['sentence_ids']):
for tmp_predict_ids, tmp_label_mask, tmp_sentence_id in zip(
batch_predict_ids.numpy(), batch_label_mask.numpy(),
batch_sentence_ids.numpy()):
......
@@ -16,6 +16,7 @@
"""Tests for official.nlp.tasks.tagging."""
import functools
import os
import numpy as np
import tensorflow as tf
......
@@ -38,15 +38,14 @@ def get_encoder_from_hub(hub_module: str) -> tf.keras.Model:
def predict(predict_step_fn: Callable[[Any], Any],
aggregate_fn: Callable[[Any, Any], Any],
dataset: tf.data.Dataset):
aggregate_fn: Callable[[Any, Any], Any], dataset: tf.data.Dataset):
"""Runs prediction.
Args:
predict_step_fn: A callable such as `def predict_step(inputs)`, where
`inputs` are input tensors.
aggregate_fn: A callable such as `def aggregate_fn(state, value)`, where
`value` is the outputs from `predict_step_fn`.
`value` is the outputs from `predict_step_fn`.
dataset: A `tf.data.Dataset` object.
Returns:
......
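The `predict` utility documented above is essentially a fold over the dataset.
A minimal sketch of the contract, illustrative rather than the module's exact
body:

import tensorflow as tf

def predict_sketch(predict_step_fn, aggregate_fn, dataset):
  state = None
  for inputs in dataset:
    outputs = predict_step_fn(inputs)
    state = aggregate_fn(state, outputs)  # fold outputs into running state
  return state

# Toy usage: square each element and collect the results.
result = predict_sketch(
    lambda x: x * x,
    lambda state, value: (state or []) + [int(value)],
    tf.data.Dataset.range(5))
# result == [0, 1, 4, 9, 16]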
@@ -88,7 +88,12 @@ class Attention(tf.keras.layers.Layer):
"attention_dropout": self.attention_dropout,
}
def call(self, query_input, source_input, bias, training, cache=None,
def call(self,
query_input,
source_input,
bias,
training,
cache=None,
decode_loop_step=None):
"""Apply attention mechanism to query_input and source_input.
@@ -102,9 +107,9 @@
cache: (Used during prediction) A dictionary with tensors containing
results of previous attentions. The dictionary must have the items:
{"k": tensor with shape [batch_size, i, heads, dim_per_head],
"v": tensor with shape [batch_size, i, heads, dim_per_head]}
where i is the current decoded length for non-padded decode, or max
sequence length for padded decode.
"v": tensor with shape [batch_size, i, heads, dim_per_head]} where
i is the current decoded length for non-padded decode, or max
sequence length for padded decode.
decode_loop_step: An integer, step number of the decoding loop. Used only
for autoregressive inference on TPU.
@@ -142,7 +147,7 @@
# Scale query to prevent the dot product between query and key from growing
# too large.
depth = (self.hidden_size // self.num_heads)
query *= depth ** -0.5
query *= depth**-0.5
# Calculate dot product attention
logits = tf.einsum("BTNH,BFNH->BNFT", key, query)
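A standalone sketch of the scaled dot-product step shown in this hunk, using
the same einsum layout (B=batch, F=query length, T=key length, N=heads,
H=dim per head); the shapes are illustrative:

import tensorflow as tf

batch, q_len, k_len, heads, dim_per_head = 2, 5, 7, 4, 16
query = tf.random.normal([batch, q_len, heads, dim_per_head])
key = tf.random.normal([batch, k_len, heads, dim_per_head])

query *= dim_per_head**-0.5  # keep dot products from growing with depth
logits = tf.einsum("BTNH,BFNH->BNFT", key, query)  # [B, N, F, T]
weights = tf.nn.softmax(logits, axis=-1)           # attend over key positions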
@@ -164,7 +169,11 @@
class SelfAttention(Attention):
"""Multiheaded self-attention layer."""
def call(self, query_input, bias, training, cache=None,
def call(self,
query_input,
bias,
training,
cache=None,
decode_loop_step=None):
return super(SelfAttention, self).call(
query_input, query_input, bias, training, cache, decode_loop_step)
return super(SelfAttention, self).call(query_input, query_input, bias,
training, cache, decode_loop_step)
@@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Beam search to find the translated sequence with the highest probability.
"""
"""Beam search to find the translated sequence with the highest probability."""
import tensorflow.compat.v1 as tf
from official.nlp.modeling.ops import beam_search
@@ -41,23 +40,27 @@ class SequenceBeamSearch(beam_search.SequenceBeamSearch):
return finished_seq, finished_scores
def sequence_beam_search(
symbols_to_logits_fn, initial_ids, initial_cache, vocab_size, beam_size,
alpha, max_decode_length, eos_id, padded_decode=False):
def sequence_beam_search(symbols_to_logits_fn,
initial_ids,
initial_cache,
vocab_size,
beam_size,
alpha,
max_decode_length,
eos_id,
padded_decode=False):
"""Search for sequence of subtoken ids with the largest probability.
Args:
symbols_to_logits_fn: A function that takes in ids, index, and cache as
arguments. The passed in arguments will have shape:
ids -> A tensor with shape [batch_size * beam_size, index].
index -> A scalar.
cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
The function must return a tuple of logits and new cache:
logits -> A tensor with shape [batch * beam_size, vocab_size].
new cache -> A nested dictionary with the same shape/structure as the
inputted cache.
initial_ids: An int32 tensor with shape [batch_size]. Starting ids for
each batch item.
arguments. The passed in arguments will have shape: ids -> A tensor with
shape [batch_size * beam_size, index]. index -> A scalar. cache -> A
nested dictionary of tensors [batch_size * beam_size, ...].
The function must return a tuple of logits and new cache: logits -> A
tensor with shape [batch * beam_size, vocab_size]. new cache -> A nested
dictionary with the same shape/structure as the inputted cache.
initial_ids: An int32 tensor with shape [batch_size]. Starting ids for each
batch item.
initial_cache: A dictionary, containing starting decoder variables
information.
vocab_size: An integer, the size of the vocabulary, used for topk
@@ -67,8 +70,8 @@ def sequence_beam_search(
max_decode_length: An integer, the maximum length to decode a sequence.
eos_id: An integer, ID of eos token, used to determine when a sequence has
finished.
padded_decode: A bool, indicating if max_sequence_length padding is used
for beam search.
padded_decode: A bool, indicating if max_sequence_length padding is used for
beam search.
Returns:
Top decoded sequences [batch_size, beam_size, max_decode_length]
......
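A toy `symbols_to_logits_fn` satisfying the contract documented above; it
scores a fixed distribution instead of running a model, and the commented call
mirrors the signature in this hunk:

import tensorflow as tf

VOCAB_SIZE = 8

def symbols_to_logits_fn(ids, index, cache):
  # ids: [batch_size * beam_size, index]; returns logits and updated cache.
  del index  # unused in this toy scorer
  batch_times_beam = tf.shape(ids)[0]
  logits = tf.tile(
      tf.math.log([[0.3, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]),
      [batch_times_beam, 1])  # [batch_size * beam_size, vocab_size]
  return logits, cache

# seqs, scores = sequence_beam_search(
#     symbols_to_logits_fn, initial_ids=tf.zeros([2], tf.int32),
#     initial_cache={}, vocab_size=VOCAB_SIZE, beam_size=4, alpha=0.6,
#     max_decode_length=10, eos_id=1)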
@@ -23,6 +23,7 @@ import random
import tarfile
# pylint: disable=g-bad-import-order
from absl import app
from absl import flags
from absl import logging
@@ -64,22 +65,18 @@ _TRAIN_DATA_SOURCES = [
# Use pre-defined minimum count to generate subtoken vocabulary.
_TRAIN_DATA_MIN_COUNT = 6
_EVAL_DATA_SOURCES = [
{
"url": "http://data.statmt.org/wmt17/translation-task/dev.tgz",
"input": "newstest2013.en",
"target": "newstest2013.de",
}
]
_EVAL_DATA_SOURCES = [{
"url": "http://data.statmt.org/wmt17/translation-task/dev.tgz",
"input": "newstest2013.en",
"target": "newstest2013.de",
}]
_TEST_DATA_SOURCES = [
{
"url": ("https://storage.googleapis.com/tf-perf-public/"
"official_transformer/test_data/newstest2014.tgz"),
"input": "newstest2014.en",
"target": "newstest2014.de",
}
]
_TEST_DATA_SOURCES = [{
"url": ("https://storage.googleapis.com/tf-perf-public/"
"official_transformer/test_data/newstest2014.tgz"),
"input": "newstest2014.en",
"target": "newstest2014.de",
}]
# Vocabulary constants
_TARGET_VOCAB_SIZE = 32768 # Number of subtokens in the vocabulary list.
@@ -114,7 +111,9 @@ def find_file(path, filename, max_depth=5):
# Download and extraction functions
###############################################################################
def get_raw_files(raw_dir, data_source):
"""Return raw files from source. Downloads/extracts if needed.
"""Return raw files from source.
Downloads/extracts if needed.
Args:
raw_dir: string directory to store raw files
@@ -134,8 +133,8 @@ def get_raw_files(raw_dir, data_source):
"targets": [],
} # keys
for d in data_source:
input_file, target_file = download_and_extract(
raw_dir, d["url"], d["input"], d["target"])
input_file, target_file = download_and_extract(raw_dir, d["url"],
d["input"], d["target"])
raw_files["inputs"].append(input_file)
raw_files["targets"].append(target_file)
return raw_files
@@ -167,7 +166,7 @@ def download_from_url(path, url):
found_file = find_file(path, filename, max_depth=0)
if found_file is None:
filename = os.path.join(path, filename)
logging.info("Downloading from %s to %s." % (url, filename))
logging.info("Downloading from %s to %s.", url, filename)
inprogress_filepath = six.ensure_str(filename) + ".incomplete"
inprogress_filepath, _ = urllib.request.urlretrieve(
url, inprogress_filepath, reporthook=download_report_hook)
@@ -176,7 +175,7 @@ def download_from_url(path, url):
tf.gfile.Rename(inprogress_filepath, filename)
return filename
else:
logging.info("Already downloaded: %s (at %s)." % (url, found_file))
logging.info("Already downloaded: %s (at %s).", url, found_file)
return found_file
@@ -199,14 +198,14 @@ def download_and_extract(path, url, input_filename, target_filename):
input_file = find_file(path, input_filename)
target_file = find_file(path, target_filename)
if input_file and target_file:
logging.info("Already downloaded and extracted %s." % url)
logging.info("Already downloaded and extracted %s.", url)
return input_file, target_file
# Download archive file if it doesn't already exist.
compressed_file = download_from_url(path, url)
# Extract compressed files
logging.info("Extracting %s." % compressed_file)
logging.info("Extracting %s.", compressed_file)
with tarfile.open(compressed_file, "r:gz") as corpus_tar:
corpus_tar.extractall(path)
@@ -236,13 +235,13 @@ def compile_files(raw_dir, raw_files, tag):
raw_files: Dict containing filenames of input and target data.
{"inputs": list of files containing data in input language
"targets": list of files containing corresponding data in target language
}
}
tag: String to append to the compiled filename.
Returns:
Full path of compiled input and target files.
"""
logging.info("Compiling files with tag %s." % tag)
logging.info("Compiling files with tag %s.", tag)
filename = "%s-%s" % (_PREFIX, tag)
input_compiled_file = os.path.join(raw_dir,
six.ensure_str(filename) + ".lang1")
@@ -255,7 +254,7 @@ def compile_files(raw_dir, raw_files, tag):
input_file = raw_files["inputs"][i]
target_file = raw_files["targets"][i]
logging.info("Reading files %s and %s." % (input_file, target_file))
logging.info("Reading files %s and %s.", input_file, target_file)
write_file(input_writer, input_file)
write_file(target_writer, target_file)
return input_compiled_file, target_compiled_file
@@ -271,8 +270,7 @@ def write_file(writer, filename):
###############################################################################
# Data preprocessing
###############################################################################
def encode_and_save_files(
subtokenizer, data_dir, raw_files, tag, total_shards):
def encode_and_save_files(subtokenizer, data_dir, raw_files, tag, total_shards):
"""Save data from files as encoded Examples in TFrecord format.
Args:
@@ -287,14 +285,16 @@ def encode_and_save_files(
List of all files produced.
"""
# Create a file for each shard.
filepaths = [shard_filename(data_dir, tag, n + 1, total_shards)
for n in range(total_shards)]
filepaths = [
shard_filename(data_dir, tag, n + 1, total_shards)
for n in range(total_shards)
]
if all_exist(filepaths):
logging.info("Files with tag %s already exist." % tag)
logging.info("Files with tag %s already exist.", tag)
return filepaths
logging.info("Saving files with tag %s." % tag)
logging.info("Saving files with tag %s.", tag)
input_file = raw_files[0]
target_file = raw_files[1]
@@ -302,13 +302,14 @@
tmp_filepaths = [six.ensure_str(fname) + ".incomplete" for fname in filepaths]
writers = [tf.python_io.TFRecordWriter(fname) for fname in tmp_filepaths]
counter, shard = 0, 0
for counter, (input_line, target_line) in enumerate(zip(
txt_line_iterator(input_file), txt_line_iterator(target_file))):
for counter, (input_line, target_line) in enumerate(
zip(txt_line_iterator(input_file), txt_line_iterator(target_file))):
if counter > 0 and counter % 100000 == 0:
logging.info("\tSaving case %d." % counter)
example = dict_to_example(
{"inputs": subtokenizer.encode(input_line, add_eos=True),
"targets": subtokenizer.encode(target_line, add_eos=True)})
logging.info("\tSaving case %d.", counter)
example = dict_to_example({
"inputs": subtokenizer.encode(input_line, add_eos=True),
"targets": subtokenizer.encode(target_line, add_eos=True)
})
writers[shard].write(example.SerializeToString())
shard = (shard + 1) % total_shards
for writer in writers:
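The loop above writes examples round-robin across shard files. dict_to_example
is not shown in this diff; the sketch below assumes it packs int64 lists into a
tf.train.Example, and it uses the TF2 tf.io.TFRecordWriter rather than the
tf.python_io alias:

import tensorflow as tf

def dict_to_example(dictionary):
  features = {
      k: tf.train.Feature(int64_list=tf.train.Int64List(value=v))
      for k, v in dictionary.items()
  }
  return tf.train.Example(features=tf.train.Features(feature=features))

def write_sharded(encoded_pairs, shard_paths):
  writers = [tf.io.TFRecordWriter(p) for p in shard_paths]
  shard = 0
  for inputs, targets in encoded_pairs:  # lists of subtoken ids
    example = dict_to_example({"inputs": inputs, "targets": targets})
    writers[shard].write(example.SerializeToString())
    shard = (shard + 1) % len(writers)  # spread cases evenly across shards
  for writer in writers:
    writer.close()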
@@ -329,7 +330,7 @@ def shard_filename(path, tag, shard_num, total_shards):
def shuffle_records(fname):
"""Shuffle records in a single file."""
logging.info("Shuffling records in file %s" % fname)
logging.info("Shuffling records in file %s", fname)
# Rename file prior to shuffling
tmp_fname = six.ensure_str(fname) + ".unshuffled"
@@ -349,7 +350,7 @@ def shuffle_records(fname):
for count, record in enumerate(records):
w.write(record)
if count > 0 and count % 100000 == 0:
logging.info("\tWriting record: %d" % count)
logging.info("\tWriting record: %d", count)
tf.gfile.Remove(tmp_fname)
@@ -372,7 +373,7 @@ def all_exist(filepaths):
def make_dir(path):
if not tf.gfile.Exists(path):
logging.info("Creating directory %s" % path)
logging.info("Creating directory %s", path)
tf.gfile.MakeDirs(path)
@@ -395,7 +396,10 @@ def main(unused_argv):
train_files_flat = train_files["inputs"] + train_files["targets"]
vocab_file = os.path.join(FLAGS.data_dir, VOCAB_FILE)
subtokenizer = tokenizer.Subtokenizer.init_from_files(
vocab_file, train_files_flat, _TARGET_VOCAB_SIZE, _TARGET_THRESHOLD,
vocab_file,
train_files_flat,
_TARGET_VOCAB_SIZE,
_TARGET_THRESHOLD,
min_count=None if FLAGS.search else _TRAIN_DATA_MIN_COUNT)
logging.info("Step 4/5: Compiling training and evaluation data")
@@ -404,12 +408,11 @@
# Tokenize and save data as Examples in the TFRecord format.
logging.info("Step 5/5: Preprocessing and saving data")
train_tfrecord_files = encode_and_save_files(
subtokenizer, FLAGS.data_dir, compiled_train_files, _TRAIN_TAG,
_TRAIN_SHARDS)
encode_and_save_files(
subtokenizer, FLAGS.data_dir, compiled_eval_files, _EVAL_TAG,
_EVAL_SHARDS)
train_tfrecord_files = encode_and_save_files(subtokenizer, FLAGS.data_dir,
compiled_train_files, _TRAIN_TAG,
_TRAIN_SHARDS)
encode_and_save_files(subtokenizer, FLAGS.data_dir, compiled_eval_files,
_EVAL_TAG, _EVAL_SHARDS)
for fname in train_tfrecord_files:
shuffle_records(fname)
@@ -418,15 +421,20 @@
def define_data_download_flags():
"""Add flags specifying data download arguments."""
flags.DEFINE_string(
name="data_dir", short_name="dd", default="/tmp/translate_ende",
name="data_dir",
short_name="dd",
default="/tmp/translate_ende",
help=flags_core.help_wrap(
"Directory for where the translate_ende_wmt32k dataset is saved."))
flags.DEFINE_string(
name="raw_dir", short_name="rd", default="/tmp/translate_ende_raw",
name="raw_dir",
short_name="rd",
default="/tmp/translate_ende_raw",
help=flags_core.help_wrap(
"Path where the raw data will be downloaded and extracted."))
flags.DEFINE_bool(
name="search", default=False,
name="search",
default=False,
help=flags_core.help_wrap(
"If set, use binary search to find the vocabulary set with size"
"closest to the target size (%d)." % _TARGET_VOCAB_SIZE))
......
@@ -87,8 +87,9 @@ def _parse_example(serialized_example):
def _filter_max_length(example, max_length=256):
"""Indicates whether the example's length is lower than the maximum length."""
return tf.logical_and(tf.size(example[0]) <= max_length,
tf.size(example[1]) <= max_length)
return tf.logical_and(
tf.size(example[0]) <= max_length,
tf.size(example[1]) <= max_length)
def _get_example_length(example):
@@ -97,8 +98,9 @@ def _get_example_length(example):
return length
def _create_min_max_boundaries(
max_length, min_boundary=_MIN_BOUNDARY, boundary_scale=_BOUNDARY_SCALE):
def _create_min_max_boundaries(max_length,
min_boundary=_MIN_BOUNDARY,
boundary_scale=_BOUNDARY_SCALE):
"""Create min and max boundary lists up to max_length.
For example, when max_length=24, min_boundary=4 and boundary_scale=2, the
@@ -165,8 +167,8 @@ def _batch_examples(dataset, batch_size, max_length):
# TODO(xunkai): investigate if removing code branching improves performance.
conditions_c = tf.logical_and(
tf.less_equal(buckets_min, seq_length),
tf.less(seq_length, buckets_max))
tf.less_equal(buckets_min, seq_length), tf.less(seq_length,
buckets_max))
bucket_id = tf.reduce_min(tf.where(conditions_c))
return bucket_id
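A hedged reconstruction of the bucketing math behind example_to_bucket_id:
boundaries grow geometrically by boundary_scale, and an example falls in the
bucket whose [min, max) range contains its length. This matches the docstring
example (max_length=24, min_boundary=4, boundary_scale=2) but is not copied
from the module:

def create_min_max_boundaries(max_length, min_boundary=4, boundary_scale=2.0):
  boundaries = []
  x = min_boundary
  while x < max_length:
    boundaries.append(x)
    x = max(x + 1, int(x * boundary_scale))
  buckets_min = [0] + boundaries               # inclusive lower bounds
  buckets_max = boundaries + [max_length + 1]  # exclusive upper bounds
  return buckets_min, buckets_max

# create_min_max_boundaries(24) -> ([0, 4, 8, 16], [4, 8, 16, 25])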
@@ -183,16 +185,23 @@ def _batch_examples(dataset, batch_size, max_length):
# lengths as well. Resulting lengths of inputs and targets can differ.
return grouped_dataset.padded_batch(bucket_batch_size, ([None], [None]))
return dataset.apply(tf.data.experimental.group_by_window(
key_func=example_to_bucket_id,
reduce_func=batching_fn,
window_size=None,
window_size_func=window_size_fn))
def _read_and_batch_from_files(
file_pattern, batch_size, max_length, max_io_parallelism, shuffle, repeat,
static_batch=False, num_replicas=1, ctx=None):
return dataset.apply(
tf.data.experimental.group_by_window(
key_func=example_to_bucket_id,
reduce_func=batching_fn,
window_size=None,
window_size_func=window_size_fn))
def _read_and_batch_from_files(file_pattern,
batch_size,
max_length,
max_io_parallelism,
shuffle,
repeat,
static_batch=False,
num_replicas=1,
ctx=None):
"""Create dataset where each item is a dict of "inputs" and "targets".
Args:
@@ -204,20 +213,18 @@ def _read_and_batch_from_files(
repeat: Number of times to repeat the dataset. If None, the dataset is
repeated forever.
static_batch: Whether the batches in the dataset should have static shapes.
If True, the input is batched so that every batch has the
shape [batch_size // max_length, max_length]. If False, the input is
grouped by length, and batched so that batches may have different
shapes [N, M], where:
N * M <= batch_size
M <= max_length
In general, this setting should be False. Dynamic shapes allow the inputs
to be grouped so that the number of padding tokens is minimized, and helps
model training. In cases where the input shape must be static
(e.g. running on TPU), this setting should be set to True.
If True, the input is batched so that every batch has the shape
[batch_size // max_length, max_length]. If False, the input is grouped by
length, and batched so that batches may have different
shapes [N, M], where N * M <= batch_size and M <= max_length. In general, this
setting should be False. Dynamic shapes allow the inputs to be grouped
so that the number of padding tokens is minimized, and helps model
training. In cases where the input shape must be static (e.g. running on
TPU), this setting should be set to True.
num_replicas: Number of GPUs or other workers. We will generate global
batches, and each global batch is equally divisible by number of replicas.
Currently it is only effective when static_batch==True. TODO: make it
effective when static_batch=False.
effective when static_batch=False.
ctx: Input context.
Returns:
@@ -240,8 +247,8 @@
# Parse each tf.Example into a dictionary
# TODO: Look into prefetch_input_elements for performance optimization.
dataset = dataset.map(_parse_example,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.map(
_parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
# Remove examples where the input or target length exceeds the maximum length.
dataset = dataset.filter(lambda x, y: _filter_max_length((x, y), max_length))
@@ -252,7 +259,8 @@
# into sentences, and finally expand to a global batch. This keeps the
# global batch evenly divisible for the distribution strategy.
int(batch_size // num_replicas // max_length * num_replicas),
([max_length], [max_length]), drop_remainder=True)
([max_length], [max_length]),
drop_remainder=True)
else:
# Group and batch such that each batch has examples of similar length.
# TODO(xunkai): _batch_examples might need to do something special for
@@ -291,10 +299,15 @@ def train_input_fn(params, ctx=None):
if params["use_synthetic_data"]:
return _generate_synthetic_data(params)
return _read_and_batch_from_files(
file_pattern, params["batch_size"], params["max_length"],
params["max_io_parallelism"], shuffle=True,
repeat=params["repeat_dataset"], static_batch=params["static_batch"],
num_replicas=params["num_gpus"], ctx=ctx)
file_pattern,
params["batch_size"],
params["max_length"],
params["max_io_parallelism"],
shuffle=True,
repeat=params["repeat_dataset"],
static_batch=params["static_batch"],
num_replicas=params["num_gpus"],
ctx=ctx)
def eval_input_fn(params, ctx=None):
@@ -303,9 +316,14 @@ def eval_input_fn(params, ctx=None):
if params["use_synthetic_data"]:
return _generate_synthetic_data(params)
return _read_and_batch_from_files(
file_pattern, params["batch_size"], params["max_length"],
params["max_io_parallelism"], shuffle=False, repeat=1,
static_batch=params["static_batch"], num_replicas=params["num_gpus"],
file_pattern,
params["batch_size"],
params["max_length"],
params["max_io_parallelism"],
shuffle=False,
repeat=1,
static_batch=params["static_batch"],
num_replicas=params["num_gpus"],
ctx=ctx)
......
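A toy illustration of the dynamic bucketing this file implements with
group_by_window: examples are grouped by a length-derived key and padded only
within their group, so short sentences are not padded up to the global
max_length. The key function and window size here are simplified stand-ins for
the bucket logic above:

import tensorflow as tf

lengths = [3, 5, 9, 2, 8, 4]
dataset = tf.data.Dataset.from_generator(
    lambda: ([0] * n for n in lengths),
    output_types=tf.int32, output_shapes=[None])

def key_fn(example):
  return tf.cast(tf.size(example) // 4, tf.int64)  # crude bucket id

def reduce_fn(unused_key, window):
  return window.padded_batch(2, padded_shapes=[None])

bucketed = dataset.apply(
    tf.data.experimental.group_by_window(
        key_func=key_fn, reduce_func=reduce_fn, window_size=2))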
@@ -60,6 +60,7 @@ class EmbeddingSharedWeights(tf.keras.layers.Layer):
Args:
inputs: An int64 tensor with shape [batch_size, length]
mode: string, a valid value is one of "embedding" and "linear".
Returns:
outputs: (1) If mode == "embedding", output embedding tensor, float32 with
shape [batch_size, length, embedding_size]; (2) mode == "linear", output
@@ -82,7 +83,7 @@
mask = tf.cast(tf.not_equal(inputs, 0), embeddings.dtype)
embeddings *= tf.expand_dims(mask, -1)
# Scale embedding by the sqrt of the hidden size
embeddings *= self.hidden_size ** 0.5
embeddings *= self.hidden_size**0.5
return embeddings
@@ -91,6 +92,7 @@ class EmbeddingSharedWeights(tf.keras.layers.Layer):
Args:
inputs: A float32 tensor with shape [batch_size, length, hidden_size]
Returns:
float32 tensor with shape [batch_size, length, vocab_size].
"""
......
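A sketch of the weight sharing in EmbeddingSharedWeights: one matrix serves
both as the input embedding (scaled by sqrt(hidden_size), with padding
positions masked out) and as the output projection in "linear" mode.
Illustrative only:

import tensorflow as tf

vocab_size, hidden_size = 100, 32
shared_weights = tf.Variable(tf.random.normal([vocab_size, hidden_size]))

def embed(ids):  # ids: int tensor [batch_size, length]
  embeddings = tf.gather(shared_weights, ids)
  mask = tf.cast(tf.not_equal(ids, 0), embeddings.dtype)
  embeddings *= tf.expand_dims(mask, -1)  # zero out padding (id 0)
  return embeddings * hidden_size**0.5    # scale as in the layer above

def linear(x):  # x: [batch_size, length, hidden_size] -> vocab logits
  return tf.matmul(x, shared_weights, transpose_b=True)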
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=g-bad-import-order
from absl import flags
import tensorflow as tf
@@ -66,28 +67,34 @@ def define_transformer_flags():
tf_gpu_thread_mode=True,
datasets_num_private_threads=True,
enable_xla=True,
fp16_implementation=True
)
fp16_implementation=True)
flags_core.define_benchmark()
flags_core.define_device(tpu=True)
flags.DEFINE_integer(
name='train_steps', short_name='ts', default=300000,
name='train_steps',
short_name='ts',
default=300000,
help=flags_core.help_wrap('The number of steps used to train.'))
flags.DEFINE_integer(
name='steps_between_evals', short_name='sbe', default=5000,
name='steps_between_evals',
short_name='sbe',
default=5000,
help=flags_core.help_wrap(
'The number of training steps to run between evaluations. This is '
'used if --train_steps is defined.'))
flags.DEFINE_boolean(
name='enable_time_history', default=True,
name='enable_time_history',
default=True,
help='Whether to enable TimeHistory callback.')
flags.DEFINE_boolean(
name='enable_tensorboard', default=False,
name='enable_tensorboard',
default=False,
help='Whether to enable Tensorboard callback.')
flags.DEFINE_boolean(
name='enable_metrics_in_training', default=False,
name='enable_metrics_in_training',
default=False,
help='Whether to enable metrics during training.')
flags.DEFINE_boolean(
name='enable_mlir_bridge',
@@ -100,7 +107,9 @@ def define_transformer_flags():
# Add transformer-specific flags
flags.DEFINE_enum(
name='param_set', short_name='mp', default='big',
name='param_set',
short_name='mp',
default='big',
enum_values=PARAMS_MAP.keys(),
help=flags_core.help_wrap(
'Parameter set to use when creating and training the model. The '
@@ -111,7 +120,9 @@
'complete list of parameters, please see model/model_params.py.'))
flags.DEFINE_bool(
name='static_batch', short_name='sb', default=False,
name='static_batch',
short_name='sb',
default=False,
help=flags_core.help_wrap(
'Whether the batches in the dataset should have static shapes. In '
'general, this setting should be False. Dynamic shapes allow the '
@@ -120,7 +131,9 @@
'must be static (e.g. running on TPU), this setting will be ignored '
'and static batching will always be used.'))
flags.DEFINE_integer(
name='max_length', short_name='ml', default=256,
name='max_length',
short_name='ml',
default=256,
help=flags_core.help_wrap(
'Max sentence length for Transformer. Default is 256. Note: Usually '
'it is more effective to use a smaller max length if static_batch is '
@@ -128,30 +141,39 @@
# Flags for training with steps (may be used for debugging)
flags.DEFINE_integer(
name='validation_steps', short_name='vs', default=64,
name='validation_steps',
short_name='vs',
default=64,
help=flags_core.help_wrap('The number of steps used in validation.'))
# BLEU score computation
flags.DEFINE_string(
name='bleu_source', short_name='bls', default=None,
name='bleu_source',
short_name='bls',
default=None,
help=flags_core.help_wrap(
'Path to source file containing text to translate when calculating the '
'official BLEU score. Both --bleu_source and --bleu_ref must be set. '
))
))
flags.DEFINE_string(
name='bleu_ref', short_name='blr', default=None,
name='bleu_ref',
short_name='blr',
default=None,
help=flags_core.help_wrap(
'Path to reference file containing the correct translations, used when '
'calculating the official BLEU score. Both --bleu_source and --bleu_ref '
'must be set. '
))
))
flags.DEFINE_string(
name='vocab_file', short_name='vf', default=None,
name='vocab_file',
short_name='vf',
default=None,
help=flags_core.help_wrap(
'Path to subtoken vocabulary file. If data_download.py was used to '
'download and encode the training data, look in the data_dir to find '
'the vocab file.'))
flags.DEFINE_string(
name='mode', default='train',
name='mode',
default='train',
help=flags_core.help_wrap('mode: train, eval, or predict'))
flags.DEFINE_bool(
name='use_ctl',
@@ -188,9 +210,10 @@
'Whether to do checkpointing during training. When running under '
'benchmark harness, we will avoid checkpointing.'))
flags_core.set_defaults(data_dir='/tmp/translate_ende',
model_dir='/tmp/transformer_model',
batch_size=None)
flags_core.set_defaults(
data_dir='/tmp/translate_ende',
model_dir='/tmp/transformer_model',
batch_size=None)
# pylint: disable=unused-variable
@flags.multi_flags_validator(
......@@ -203,11 +226,12 @@ def define_transformer_flags():
@flags.multi_flags_validator(
['bleu_source', 'bleu_ref', 'vocab_file'],
message='--vocab_file must be defined if --bleu_source and --bleu_ref '
'are defined.')
'are defined.')
def _check_bleu_vocab_file(flags_dict):
if flags_dict['bleu_source'] and flags_dict['bleu_ref']:
return flags_dict['vocab_file'] is not None
return True
# pylint: enable=unused-variable
@@ -256,5 +280,5 @@ def update_stats(history, stats, callbacks):
if len(timestamp_log) > 1:
stats['avg_exp_per_second'] = (
callback.batch_size * callback.log_steps *
(len(callback.timestamp_log)-1) /
(len(callback.timestamp_log) - 1) /
(timestamp_log[-1].timestamp - timestamp_log[0].timestamp))
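A worked example of the throughput formula above, with made-up numbers: each
interval between logged timestamps covers batch_size * log_steps examples, so
examples per second is that total divided by the elapsed wall time:

batch_size, log_steps = 4096, 100
timestamps = [100.0, 160.0, 221.0, 280.0]  # seconds, one per log point
avg_exp_per_second = (
    batch_size * log_steps * (len(timestamps) - 1) /
    (timestamps[-1] - timestamps[0]))
# 4096 * 100 * 3 / 180.0 = 6826.67 examples/second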