Commit 999fae62 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 326286926
parent 94561082
@@ -71,10 +71,8 @@ class LearningRateSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
 def create_optimizer(params: params_dict.ParamsDict):
   """Creates optimizer."""
-  lr_schedule = LearningRateSchedule(
-      params.learning_rate,
-      params.hidden_size,
-      params.learning_rate_warmup_steps)
+  lr_schedule = LearningRateSchedule(params.learning_rate, params.hidden_size,
+                                     params.learning_rate_warmup_steps)
   return tf.keras.optimizers.Adam(
       learning_rate=lr_schedule,
       beta_1=params.adam_beta1,
......
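The `LearningRateSchedule` constructed above takes the base rate, the model's hidden size, and a warmup step count, which is the signature of the classic Transformer schedule: linear warmup followed by inverse-square-root decay, scaled by hidden_size**-0.5. A minimal sketch under that assumption (the class body itself is not shown in this diff):

    import tensorflow as tf

    class RsqrtWarmupSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
      """Hypothetical stand-in for the schedule wired up above."""

      def __init__(self, learning_rate, hidden_size, warmup_steps):
        super().__init__()
        self.learning_rate = learning_rate
        self.hidden_size = hidden_size
        self.warmup_steps = float(warmup_steps)

      def __call__(self, step):
        step = tf.cast(step, tf.float32)
        # Scale by hidden size, warm up linearly, then decay as 1/sqrt(step).
        lr = self.learning_rate * self.hidden_size**-0.5
        lr *= tf.minimum(1.0, step / self.warmup_steps)
        lr /= tf.maximum(step, self.warmup_steps)**0.5
        return lr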
@@ -16,6 +16,7 @@
 """Processes crawled content from news URLs by generating tfrecords."""
 import os
+
 from absl import app
 from absl import flags
 from official.nlp.nhnet import raw_data_processor
......
@@ -20,6 +20,7 @@ import json
 import multiprocessing
 import os
 import urllib.parse
+
 import tensorflow as tf
 from official.nlp.bert import tokenization
@@ -47,10 +48,10 @@ class RawDataProcessor(object):
       max_num_articles: Maximum number of articles in a story.
       include_article_title_in_passage: Whether to include article title in
         article passage.
-      include_text_snippet_in_example: Whether to include text snippet
-        (headline and article content) in generated tensorflow Examples, for
-        debug usage. If include_article_title_in_passage=True, title and body
-        will be separated by [SEP].
+      include_text_snippet_in_example: Whether to include text snippet (headline
+        and article content) in generated tensorflow Examples, for debug usage.
+        If include_article_title_in_passage=True, title and body will be
+        separated by [SEP].
     """
     self.articles = dict()
     self.tokenizer = tokenization.FullTokenizer(
@@ -156,6 +157,7 @@ class RawDataProcessor(object):
   def _get_single_story_features(self, story_headline, articles):
     """Converts a list of articles to a tensorflow Example."""
+
     def get_text_snippet(article):
       if article.text_b:
         return " [SEP] ".join([article.text_a, article.text_b])
......
@@ -21,6 +21,7 @@ from __future__ import print_function
 import os
+# Import libraries
 from absl import app
 from absl import flags
 from absl import logging
......
@@ -44,6 +44,8 @@ def encoder_common_layers(transformer_block):
       transformer_block._intermediate_dense, transformer_block._output_dense,
       transformer_block._output_layer_norm
   ]
   # pylint: enable=protected-access
+
+
......
@@ -14,6 +14,7 @@
 # limitations under the License.
 # ==============================================================================
 """ELECTRA pretraining task (Joint Masked LM and Replaced Token Detection)."""
 import dataclasses
+
 import tensorflow as tf
......
@@ -14,6 +14,7 @@
 # limitations under the License.
 # ==============================================================================
 """Masked language task."""
 import dataclasses
+
 import tensorflow as tf
......
@@ -52,8 +52,7 @@ class MLMTaskTest(tf.test.TestCase):
     task.validation_step(next(iterator), model, metrics=metrics)
     # Saves a checkpoint.
-    ckpt = tf.train.Checkpoint(
-        model=model, **model.checkpoint_items)
+    ckpt = tf.train.Checkpoint(model=model, **model.checkpoint_items)
     ckpt.save(config.init_checkpoint)
     task.initialize(model)
......
@@ -111,9 +111,7 @@ class QuestionAnsweringTask(base_task.Task):
         tf.cast(start_logits, dtype=tf.float32),
         from_logits=True)
     end_loss = tf.keras.losses.sparse_categorical_crossentropy(
-        end_positions,
-        tf.cast(end_logits, dtype=tf.float32),
-        from_logits=True)
+        end_positions, tf.cast(end_logits, dtype=tf.float32), from_logits=True)
     loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
     return loss
@@ -142,8 +140,7 @@ class QuestionAnsweringTask(base_task.Task):
     kwargs = dict(
         examples=eval_examples,
         tokenizer=tokenization.FullTokenizer(
-            vocab_file=params.vocab_file,
-            do_lower_case=params.do_lower_case),
+            vocab_file=params.vocab_file, do_lower_case=params.do_lower_case),
         max_seq_length=params.seq_length,
         doc_stride=params.doc_stride,
         max_query_length=params.query_length,
@@ -192,8 +189,8 @@ class QuestionAnsweringTask(base_task.Task):
       input_path = self._tf_record_input_path
     dataloader_params = params.replace(input_path=input_path)
-    return data_loader_factory.get_data_loader(
-        dataloader_params).load(input_context)
+    return data_loader_factory.get_data_loader(dataloader_params).load(
+        input_context)

   def build_metrics(self, training=None):
     del training
@@ -209,16 +206,19 @@ class QuestionAnsweringTask(base_task.Task):
   def process_metrics(self, metrics, labels, model_outputs):
     metrics = dict([(metric.name, metric) for metric in metrics])
     start_logits, end_logits = model_outputs
-    metrics['start_position_accuracy'].update_state(
-        labels['start_positions'], start_logits)
-    metrics['end_position_accuracy'].update_state(
-        labels['end_positions'], end_logits)
+    metrics['start_position_accuracy'].update_state(labels['start_positions'],
+                                                    start_logits)
+    metrics['end_position_accuracy'].update_state(labels['end_positions'],
+                                                  end_logits)

   def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
     start_logits, end_logits = model_outputs
     compiled_metrics.update_state(
         y_true=labels,  # labels has keys 'start_positions' and 'end_positions'.
-        y_pred={'start_positions': start_logits, 'end_positions': end_logits})
+        y_pred={
+            'start_positions': start_logits,
+            'end_positions': end_logits
+        })

   def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
     features, _ = inputs
@@ -242,16 +242,16 @@ class QuestionAnsweringTask(base_task.Task):
     state = []
     for unique_ids, start_logits, end_logits in zip(
-        step_outputs['unique_ids'],
-        step_outputs['start_logits'],
+        step_outputs['unique_ids'], step_outputs['start_logits'],
         step_outputs['end_logits']):
-      u_ids, s_logits, e_logits = (
-          unique_ids.numpy(), start_logits.numpy(), end_logits.numpy())
+      u_ids, s_logits, e_logits = (unique_ids.numpy(), start_logits.numpy(),
+                                   end_logits.numpy())
       for values in zip(u_ids, s_logits, e_logits):
-        state.append(self.raw_aggregated_result(
-            unique_id=values[0],
-            start_logits=values[1].tolist(),
-            end_logits=values[2].tolist()))
+        state.append(
+            self.raw_aggregated_result(
+                unique_id=values[0],
+                start_logits=values[1].tolist(),
+                end_logits=values[2].tolist()))
     return state

   def reduce_aggregated_logs(self, aggregated_logs):
@@ -269,13 +269,13 @@ class QuestionAnsweringTask(base_task.Task):
             self.task_config.null_score_diff_threshold),
         verbose=False))

-    with tf.io.gfile.GFile(
-        self.task_config.validation_data.input_path, 'r') as reader:
+    with tf.io.gfile.GFile(self.task_config.validation_data.input_path,
+                           'r') as reader:
       dataset_json = json.load(reader)
       pred_dataset = dataset_json['data']
     if self.task_config.validation_data.version_2_with_negative:
-      eval_metrics = squad_evaluate_v2_0.evaluate(
-          pred_dataset, all_predictions, scores_diff)
+      eval_metrics = squad_evaluate_v2_0.evaluate(pred_dataset, all_predictions,
+                                                  scores_diff)
       # Filter out useless metrics, such as start_position_accuracy that
      # we did not actually compute.
       eval_metrics = {
@@ -284,13 +284,16 @@ class QuestionAnsweringTask(base_task.Task):
           'final_f1': eval_metrics['final_f1'] / 100.0,  # scale back to [0, 1].
           'f1_threshold': eval_metrics['final_f1_thresh'],
           'has_answer_exact_match': eval_metrics['HasAns_exact'],
-          'has_answer_f1': eval_metrics['HasAns_f1']}
+          'has_answer_f1': eval_metrics['HasAns_f1']
+      }
     else:
       eval_metrics = squad_evaluate_v1_1.evaluate(pred_dataset, all_predictions)
       # Filter out useless metrics, such as start_position_accuracy that
       # we did not actually compute.
-      eval_metrics = {'exact_match': eval_metrics['exact_match'],
-                      'final_f1': eval_metrics['final_f1']}
+      eval_metrics = {
+          'exact_match': eval_metrics['exact_match'],
+          'final_f1': eval_metrics['final_f1']
+      }
     return eval_metrics
......
@@ -17,6 +17,7 @@
 import itertools
 import json
 import os
+
 from absl.testing import parameterized
 import tensorflow as tf
......
@@ -35,7 +35,6 @@ from official.nlp.data import data_loader_factory
 from official.nlp.modeling import models
 from official.nlp.tasks import utils
-
 METRIC_TYPES = frozenset(
     ['accuracy', 'matthews_corrcoef', 'pearson_spearman_corr'])
@@ -137,7 +136,8 @@ class SentencePredictionTask(base_task.Task):
       metrics = [tf.keras.metrics.MeanSquaredError()]
     else:
       metrics = [
-          tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')]
+          tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')
+      ]
     return metrics

   def process_metrics(self, metrics, labels, model_outputs):
......
@@ -250,8 +250,7 @@ def predict(task: TaggingTask, params: cfg.DataConfig,
     cur_predict_ids = state['predict_ids']
     cur_sentence_ids = state['sentence_ids']
     for batch_predict_ids, batch_label_mask, batch_sentence_ids in zip(
-        outputs['predict_ids'], outputs['label_mask'],
-        outputs['sentence_ids']):
+        outputs['predict_ids'], outputs['label_mask'], outputs['sentence_ids']):
       for tmp_predict_ids, tmp_label_mask, tmp_sentence_id in zip(
           batch_predict_ids.numpy(), batch_label_mask.numpy(),
           batch_sentence_ids.numpy()):
......
@@ -16,6 +16,7 @@
 """Tests for official.nlp.tasks.tagging."""
 import functools
 import os
+
 import numpy as np
 import tensorflow as tf
......
@@ -38,15 +38,14 @@ def get_encoder_from_hub(hub_module: str) -> tf.keras.Model:
 def predict(predict_step_fn: Callable[[Any], Any],
-            aggregate_fn: Callable[[Any, Any], Any],
-            dataset: tf.data.Dataset):
+            aggregate_fn: Callable[[Any, Any], Any], dataset: tf.data.Dataset):
   """Runs prediction.

   Args:
     predict_step_fn: A callable such as `def predict_step(inputs)`, where
       `inputs` are input tensors.
     aggregate_fn: A callable such as `def aggregate_fn(state, value)`, where
       `value` is the outputs from `predict_step_fn`.
     dataset: A `tf.data.Dataset` object.

   Returns:
......
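The `predict` contract in the docstring above is fully generic: `predict_step_fn` maps one batch of inputs to outputs, and `aggregate_fn` folds each step's outputs into a running state. A hypothetical usage sketch (the names below are illustrative, not part of this diff; the initial state is assumed to be None):

    def predict_step(inputs):
      # Assumes some `model` object in scope; runs one batch through it.
      return model(inputs, training=False)

    def aggregate_fn(state, value):
      state = state or []
      state.append(value)
      return state

    # results = predict(predict_step, aggregate_fn, dataset)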
@@ -88,7 +88,12 @@ class Attention(tf.keras.layers.Layer):
         "attention_dropout": self.attention_dropout,
     }

-  def call(self, query_input, source_input, bias, training, cache=None,
+  def call(self,
+           query_input,
+           source_input,
+           bias,
+           training,
+           cache=None,
            decode_loop_step=None):
     """Apply attention mechanism to query_input and source_input.
@@ -102,9 +107,9 @@ class Attention(tf.keras.layers.Layer):
       cache: (Used during prediction) A dictionary with tensors containing
         results of previous attentions. The dictionary must have the items:
             {"k": tensor with shape [batch_size, i, heads, dim_per_head],
-             "v": tensor with shape [batch_size, i, heads, dim_per_head]}
-        where i is the current decoded length for non-padded decode, or max
-        sequence length for padded decode.
+             "v": tensor with shape [batch_size, i, heads, dim_per_head]} where
+               i is the current decoded length for non-padded decode, or max
+               sequence length for padded decode.
       decode_loop_step: An integer, step number of the decoding loop. Used only
         for autoregressive inference on TPU.
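The cache layout the docstring describes can be allocated up front for padded decode, where `i` is fixed at the maximum sequence length. A sketch under those stated shapes (the helper name is illustrative, not from this diff):

    import tensorflow as tf

    def init_decode_cache(batch_size, max_decode_length, heads, dim_per_head):
      # Zero-filled tensors are overwritten in place at each decoding step.
      shape = [batch_size, max_decode_length, heads, dim_per_head]
      return {"k": tf.zeros(shape), "v": tf.zeros(shape)}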
@@ -142,7 +147,7 @@ class Attention(tf.keras.layers.Layer):
     # Scale query to prevent the dot product between query and key from growing
     # too large.
     depth = (self.hidden_size // self.num_heads)
-    query *= depth ** -0.5
+    query *= depth**-0.5

     # Calculate dot product attention
     logits = tf.einsum("BTNH,BFNH->BNFT", key, query)
@@ -164,7 +169,11 @@ class Attention(tf.keras.layers.Layer):
 class SelfAttention(Attention):
   """Multiheaded self-attention layer."""

-  def call(self, query_input, bias, training, cache=None,
+  def call(self,
+           query_input,
+           bias,
+           training,
+           cache=None,
            decode_loop_step=None):
-    return super(SelfAttention, self).call(
-        query_input, query_input, bias, training, cache, decode_loop_step)
+    return super(SelfAttention, self).call(query_input, query_input, bias,
+                                           training, cache, decode_loop_step)
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Beam search to find the translated sequence with the highest probability.
-"""
+"""Beam search to find the translated sequence with the highest probability."""

 import tensorflow.compat.v1 as tf

 from official.nlp.modeling.ops import beam_search
@@ -41,23 +40,27 @@ class SequenceBeamSearch(beam_search.SequenceBeamSearch):
     return finished_seq, finished_scores


-def sequence_beam_search(
-    symbols_to_logits_fn, initial_ids, initial_cache, vocab_size, beam_size,
-    alpha, max_decode_length, eos_id, padded_decode=False):
+def sequence_beam_search(symbols_to_logits_fn,
+                         initial_ids,
+                         initial_cache,
+                         vocab_size,
+                         beam_size,
+                         alpha,
+                         max_decode_length,
+                         eos_id,
+                         padded_decode=False):
   """Search for sequence of subtoken ids with the largest probability.

   Args:
     symbols_to_logits_fn: A function that takes in ids, index, and cache as
-      arguments. The passed in arguments will have shape:
-        ids -> A tensor with shape [batch_size * beam_size, index].
-        index -> A scalar.
-        cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
-      The function must return a tuple of logits and new cache:
-        logits -> A tensor with shape [batch * beam_size, vocab_size].
-        new cache -> A nested dictionary with the same shape/structure as the
-          inputted cache.
-    initial_ids: An int32 tensor with shape [batch_size]. Starting ids for
-      each batch item.
+      arguments. The passed in arguments will have shape: ids -> A tensor with
+        shape [batch_size * beam_size, index]. index -> A scalar. cache -> A
+        nested dictionary of tensors [batch_size * beam_size, ...].
+      The function must return a tuple of logits and new cache: logits -> A
+        tensor with shape [batch * beam_size, vocab_size]. new cache -> A nested
+        dictionary with the same shape/structure as the inputted cache.
+    initial_ids: An int32 tensor with shape [batch_size]. Starting ids for each
+      batch item.
     initial_cache: A dictionary, containing starting decoder variables
       information.
     vocab_size: An integer, the size of the vocabulary, used for topk
@@ -67,8 +70,8 @@ def sequence_beam_search(
     max_decode_length: An integer, the maximum length to decoded a sequence.
     eos_id: An integer, ID of eos token, used to determine when a sequence has
       finished.
-    padded_decode: A bool, indicating if max_sequence_length padding is used
-      for beam search.
+    padded_decode: A bool, indicating if max_sequence_length padding is used for
      beam search.

   Returns:
     Top decoded sequences [batch_size, beam_size, max_decode_length]
......
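A minimal caller satisfying the `symbols_to_logits_fn` contract spelled out above; the uniform-random logits are a stand-in for a real decoder, so this sketch only demonstrates the shapes:

    import tensorflow.compat.v1 as tf

    batch_size, beam_size, vocab_size = 2, 4, 6

    def symbols_to_logits_fn(ids, index, cache):
      # ids: [batch_size * beam_size, index]; returns next-token logits.
      logits = tf.random.uniform([batch_size * beam_size, vocab_size])
      return logits, cache

    # seqs, scores = sequence_beam_search(
    #     symbols_to_logits_fn, tf.zeros([batch_size], tf.int32), {},
    #     vocab_size, beam_size, alpha=0.6, max_decode_length=10, eos_id=1)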
@@ -23,6 +23,7 @@ import random
 import tarfile
+
 # pylint: disable=g-bad-import-order
 from absl import app
 from absl import flags
 from absl import logging
@@ -64,22 +65,18 @@ _TRAIN_DATA_SOURCES = [
 # Use pre-defined minimum count to generate subtoken vocabulary.
 _TRAIN_DATA_MIN_COUNT = 6

-_EVAL_DATA_SOURCES = [
-    {
-        "url": "http://data.statmt.org/wmt17/translation-task/dev.tgz",
-        "input": "newstest2013.en",
-        "target": "newstest2013.de",
-    }
-]
+_EVAL_DATA_SOURCES = [{
+    "url": "http://data.statmt.org/wmt17/translation-task/dev.tgz",
+    "input": "newstest2013.en",
+    "target": "newstest2013.de",
+}]

-_TEST_DATA_SOURCES = [
-    {
-        "url": ("https://storage.googleapis.com/tf-perf-public/"
-                "official_transformer/test_data/newstest2014.tgz"),
-        "input": "newstest2014.en",
-        "target": "newstest2014.de",
-    }
-]
+_TEST_DATA_SOURCES = [{
+    "url": ("https://storage.googleapis.com/tf-perf-public/"
+            "official_transformer/test_data/newstest2014.tgz"),
+    "input": "newstest2014.en",
+    "target": "newstest2014.de",
+}]

 # Vocabulary constants
 _TARGET_VOCAB_SIZE = 32768  # Number of subtokens in the vocabulary list.
@@ -114,7 +111,9 @@ def find_file(path, filename, max_depth=5):
 # Download and extraction functions
 ###############################################################################
 def get_raw_files(raw_dir, data_source):
-  """Return raw files from source. Downloads/extracts if needed.
+  """Return raw files from source.
+
+  Downloads/extracts if needed.

   Args:
     raw_dir: string directory to store raw files
@@ -134,8 +133,8 @@ def get_raw_files(raw_dir, data_source):
       "targets": [],
   }  # keys
   for d in data_source:
-    input_file, target_file = download_and_extract(
-        raw_dir, d["url"], d["input"], d["target"])
+    input_file, target_file = download_and_extract(raw_dir, d["url"],
+                                                   d["input"], d["target"])
     raw_files["inputs"].append(input_file)
     raw_files["targets"].append(target_file)
   return raw_files
@@ -167,7 +166,7 @@ def download_from_url(path, url):
   found_file = find_file(path, filename, max_depth=0)
   if found_file is None:
     filename = os.path.join(path, filename)
-    logging.info("Downloading from %s to %s." % (url, filename))
+    logging.info("Downloading from %s to %s.", url, filename)
     inprogress_filepath = six.ensure_str(filename) + ".incomplete"
     inprogress_filepath, _ = urllib.request.urlretrieve(
         url, inprogress_filepath, reporthook=download_report_hook)
@@ -176,7 +175,7 @@ def download_from_url(path, url):
     tf.gfile.Rename(inprogress_filepath, filename)
     return filename
   else:
-    logging.info("Already downloaded: %s (at %s)." % (url, found_file))
+    logging.info("Already downloaded: %s (at %s).", url, found_file)
     return found_file
@@ -199,14 +198,14 @@ def download_and_extract(path, url, input_filename, target_filename):
   input_file = find_file(path, input_filename)
   target_file = find_file(path, target_filename)
   if input_file and target_file:
-    logging.info("Already downloaded and extracted %s." % url)
+    logging.info("Already downloaded and extracted %s.", url)
     return input_file, target_file

   # Download archive file if it doesn't already exist.
   compressed_file = download_from_url(path, url)

   # Extract compressed files
-  logging.info("Extracting %s." % compressed_file)
+  logging.info("Extracting %s.", compressed_file)
   with tarfile.open(compressed_file, "r:gz") as corpus_tar:
     corpus_tar.extractall(path)
@@ -236,13 +235,13 @@ def compile_files(raw_dir, raw_files, tag):
     raw_files: Dict containing filenames of input and target data.
       {"inputs": list of files containing data in input language
       "targets": list of files containing corresponding data in target language
       }
     tag: String to append to the compiled filename.

   Returns:
     Full path of compiled input and target files.
   """
-  logging.info("Compiling files with tag %s." % tag)
+  logging.info("Compiling files with tag %s.", tag)
   filename = "%s-%s" % (_PREFIX, tag)
   input_compiled_file = os.path.join(raw_dir,
                                      six.ensure_str(filename) + ".lang1")
@@ -255,7 +254,7 @@ def compile_files(raw_dir, raw_files, tag):
     input_file = raw_files["inputs"][i]
     target_file = raw_files["targets"][i]

-    logging.info("Reading files %s and %s." % (input_file, target_file))
+    logging.info("Reading files %s and %s.", input_file, target_file)
     write_file(input_writer, input_file)
     write_file(target_writer, target_file)
   return input_compiled_file, target_compiled_file
@@ -271,8 +270,7 @@ def write_file(writer, filename):
 ###############################################################################
 # Data preprocessing
 ###############################################################################
-def encode_and_save_files(
-    subtokenizer, data_dir, raw_files, tag, total_shards):
+def encode_and_save_files(subtokenizer, data_dir, raw_files, tag, total_shards):
   """Save data from files as encoded Examples in TFrecord format.

   Args:
@@ -287,14 +285,16 @@ def encode_and_save_files(
     List of all files produced.
   """
   # Create a file for each shard.
-  filepaths = [shard_filename(data_dir, tag, n + 1, total_shards)
-               for n in range(total_shards)]
+  filepaths = [
+      shard_filename(data_dir, tag, n + 1, total_shards)
+      for n in range(total_shards)
+  ]

   if all_exist(filepaths):
-    logging.info("Files with tag %s already exist." % tag)
+    logging.info("Files with tag %s already exist.", tag)
     return filepaths

-  logging.info("Saving files with tag %s." % tag)
+  logging.info("Saving files with tag %s.", tag)
   input_file = raw_files[0]
   target_file = raw_files[1]
@@ -302,13 +302,14 @@ def encode_and_save_files(
   tmp_filepaths = [six.ensure_str(fname) + ".incomplete" for fname in filepaths]
   writers = [tf.python_io.TFRecordWriter(fname) for fname in tmp_filepaths]
   counter, shard = 0, 0
-  for counter, (input_line, target_line) in enumerate(zip(
-      txt_line_iterator(input_file), txt_line_iterator(target_file))):
+  for counter, (input_line, target_line) in enumerate(
+      zip(txt_line_iterator(input_file), txt_line_iterator(target_file))):
     if counter > 0 and counter % 100000 == 0:
-      logging.info("\tSaving case %d." % counter)
-    example = dict_to_example(
-        {"inputs": subtokenizer.encode(input_line, add_eos=True),
-         "targets": subtokenizer.encode(target_line, add_eos=True)})
+      logging.info("\tSaving case %d.", counter)
+    example = dict_to_example({
+        "inputs": subtokenizer.encode(input_line, add_eos=True),
+        "targets": subtokenizer.encode(target_line, add_eos=True)
+    })
     writers[shard].write(example.SerializeToString())
     shard = (shard + 1) % total_shards
   for writer in writers:
@@ -329,7 +330,7 @@ def shard_filename(path, tag, shard_num, total_shards):

 def shuffle_records(fname):
   """Shuffle records in a single file."""
-  logging.info("Shuffling records in file %s" % fname)
+  logging.info("Shuffling records in file %s", fname)

   # Rename file prior to shuffling
   tmp_fname = six.ensure_str(fname) + ".unshuffled"
@@ -349,7 +350,7 @@ def shuffle_records(fname):
     for count, record in enumerate(records):
       w.write(record)
       if count > 0 and count % 100000 == 0:
-        logging.info("\tWriting record: %d" % count)
+        logging.info("\tWriting record: %d", count)
   tf.gfile.Remove(tmp_fname)
@@ -372,7 +373,7 @@ def all_exist(filepaths):
 def make_dir(path):
   if not tf.gfile.Exists(path):
-    logging.info("Creating directory %s" % path)
+    logging.info("Creating directory %s", path)
     tf.gfile.MakeDirs(path)
@@ -395,7 +396,10 @@ def main(unused_argv):
   train_files_flat = train_files["inputs"] + train_files["targets"]
   vocab_file = os.path.join(FLAGS.data_dir, VOCAB_FILE)
   subtokenizer = tokenizer.Subtokenizer.init_from_files(
-      vocab_file, train_files_flat, _TARGET_VOCAB_SIZE, _TARGET_THRESHOLD,
+      vocab_file,
+      train_files_flat,
+      _TARGET_VOCAB_SIZE,
+      _TARGET_THRESHOLD,
       min_count=None if FLAGS.search else _TRAIN_DATA_MIN_COUNT)

   logging.info("Step 4/5: Compiling training and evaluation data")
@@ -404,12 +408,11 @@ def main(unused_argv):
   # Tokenize and save data as Examples in the TFRecord format.
   logging.info("Step 5/5: Preprocessing and saving data")
-  train_tfrecord_files = encode_and_save_files(
-      subtokenizer, FLAGS.data_dir, compiled_train_files, _TRAIN_TAG,
-      _TRAIN_SHARDS)
-  encode_and_save_files(
-      subtokenizer, FLAGS.data_dir, compiled_eval_files, _EVAL_TAG,
-      _EVAL_SHARDS)
+  train_tfrecord_files = encode_and_save_files(subtokenizer, FLAGS.data_dir,
+                                               compiled_train_files, _TRAIN_TAG,
+                                               _TRAIN_SHARDS)
+  encode_and_save_files(subtokenizer, FLAGS.data_dir, compiled_eval_files,
+                        _EVAL_TAG, _EVAL_SHARDS)
   for fname in train_tfrecord_files:
     shuffle_records(fname)
@@ -418,15 +421,20 @@ def main(unused_argv):
 def define_data_download_flags():
   """Add flags specifying data download arguments."""
   flags.DEFINE_string(
-      name="data_dir", short_name="dd", default="/tmp/translate_ende",
+      name="data_dir",
+      short_name="dd",
+      default="/tmp/translate_ende",
       help=flags_core.help_wrap(
           "Directory for where the translate_ende_wmt32k dataset is saved."))
   flags.DEFINE_string(
-      name="raw_dir", short_name="rd", default="/tmp/translate_ende_raw",
+      name="raw_dir",
+      short_name="rd",
+      default="/tmp/translate_ende_raw",
       help=flags_core.help_wrap(
          "Path where the raw data will be downloaded and extracted."))
   flags.DEFINE_bool(
-      name="search", default=False,
+      name="search",
+      default=False,
       help=flags_core.help_wrap(
           "If set, use binary search to find the vocabulary set with size"
          "closest to the target size (%d)." % _TARGET_VOCAB_SIZE))
......
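Given the flags defined above, a typical invocation of this download script (referred to elsewhere in this commit as data_download.py) would look like:

    python data_download.py --data_dir=/tmp/translate_ende --raw_dir=/tmp/translate_ende_raw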
@@ -87,8 +87,9 @@ def _parse_example(serialized_example):
 def _filter_max_length(example, max_length=256):
   """Indicates whether the example's length is lower than the maximum length."""
-  return tf.logical_and(tf.size(example[0]) <= max_length,
-                        tf.size(example[1]) <= max_length)
+  return tf.logical_and(
+      tf.size(example[0]) <= max_length,
+      tf.size(example[1]) <= max_length)


 def _get_example_length(example):
@@ -97,8 +98,9 @@ def _get_example_length(example):
   return length


-def _create_min_max_boundaries(
-    max_length, min_boundary=_MIN_BOUNDARY, boundary_scale=_BOUNDARY_SCALE):
+def _create_min_max_boundaries(max_length,
+                               min_boundary=_MIN_BOUNDARY,
+                               boundary_scale=_BOUNDARY_SCALE):
   """Create min and max boundary lists up to max_length.

   For example, when max_length=24, min_boundary=4 and boundary_scale=2, the
@@ -165,8 +167,8 @@ def _batch_examples(dataset, batch_size, max_length):
     # TODO(xunkai): investigate if removing code branching improves performance.
     conditions_c = tf.logical_and(
-        tf.less_equal(buckets_min, seq_length),
-        tf.less(seq_length, buckets_max))
+        tf.less_equal(buckets_min, seq_length), tf.less(seq_length,
+                                                        buckets_max))
     bucket_id = tf.reduce_min(tf.where(conditions_c))
     return bucket_id
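To make the bucketing above concrete: `_create_min_max_boundaries` produces per-bucket [min, max) length ranges, and `example_to_bucket_id` picks the bucket whose range contains `seq_length`. A sketch of the boundary construction, reconstructed from the docstring's example (max_length=24, min_boundary=4, boundary_scale=2) rather than copied from the file:

    def create_min_max_boundaries(max_length, min_boundary=4, boundary_scale=2):
      # Boundaries grow roughly geometrically until they reach max_length.
      bucket_boundaries = []
      x = min_boundary
      while x < max_length:
        bucket_boundaries.append(x)
        x = max(x + 1, int(x * boundary_scale))
      buckets_min = [0] + bucket_boundaries
      buckets_max = bucket_boundaries + [max_length + 1]
      return buckets_min, buckets_max

    # create_min_max_boundaries(24) -> ([0, 4, 8, 16], [4, 8, 16, 25])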
@@ -183,16 +185,23 @@ def _batch_examples(dataset, batch_size, max_length):
     # lengths as well. Resulting lengths of inputs and targets can differ.
     return grouped_dataset.padded_batch(bucket_batch_size, ([None], [None]))

-  return dataset.apply(tf.data.experimental.group_by_window(
-      key_func=example_to_bucket_id,
-      reduce_func=batching_fn,
-      window_size=None,
-      window_size_func=window_size_fn))
+  return dataset.apply(
+      tf.data.experimental.group_by_window(
+          key_func=example_to_bucket_id,
+          reduce_func=batching_fn,
+          window_size=None,
+          window_size_func=window_size_fn))


-def _read_and_batch_from_files(
-    file_pattern, batch_size, max_length, max_io_parallelism, shuffle, repeat,
-    static_batch=False, num_replicas=1, ctx=None):
+def _read_and_batch_from_files(file_pattern,
+                               batch_size,
+                               max_length,
+                               max_io_parallelism,
+                               shuffle,
+                               repeat,
+                               static_batch=False,
+                               num_replicas=1,
+                               ctx=None):
   """Create dataset where each item is a dict of "inputs" and "targets".

   Args:
@@ -204,20 +213,18 @@ def _read_and_batch_from_files(
     repeat: Number of times to repeat the dataset. If None, the dataset is
       repeated forever.
     static_batch: Whether the batches in the dataset should have static shapes.
-      If True, the input is batched so that every batch has the
-      shape [batch_size // max_length, max_length]. If False, the input is
-      grouped by length, and batched so that batches may have different
-      shapes [N, M], where:
-        N * M <= batch_size
-        M <= max_length
-      In general, this setting should be False. Dynamic shapes allow the inputs
-      to be grouped so that the number of padding tokens is minimized, and helps
-      model training. In cases where the input shape must be static
-      (e.g. running on TPU), this setting should be set to True.
+      If True, the input is batched so that every batch has the shape
+      [batch_size // max_length, max_length]. If False, the input is grouped by
+      length, and batched so that batches may have different
+      shapes [N, M], where: N * M <= batch_size M <= max_length In general, this
+      setting should be False. Dynamic shapes allow the inputs to be grouped
+      so that the number of padding tokens is minimized, and helps model
+      training. In cases where the input shape must be static (e.g. running on
+      TPU), this setting should be set to True.
     num_replicas: Number of GPUs or other workers. We will generate global
       batches, and each global batch is equally divisible by number of replicas.
       Currently it is only effective when static_batch==True. TODO: make it
       effective when static_batch=False.
     ctx: Input context.

   Returns:
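The static_batch paragraph above pins down the batch shape exactly, and the arithmetic is easy to check by hand. A small worked sketch using the same expression that appears later in this diff for the padded-batch size:

    batch_size, max_length, num_replicas = 4096, 256, 2
    # The global batch must divide evenly across replicas.
    sentences_per_batch = int(batch_size // num_replicas // max_length * num_replicas)
    print(sentences_per_batch)  # 16, i.e. static batches of shape [16, 256]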
@@ -240,8 +247,8 @@ def _read_and_batch_from_files(
   # Parse each tf.Example into a dictionary
   # TODO: Look into prefetch_input_elements for performance optimization.
-  dataset = dataset.map(_parse_example,
-                        num_parallel_calls=tf.data.experimental.AUTOTUNE)
+  dataset = dataset.map(
+      _parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)

   # Remove examples where the input or target length exceeds the maximum length,
   dataset = dataset.filter(lambda x, y: _filter_max_length((x, y), max_length))
@@ -252,7 +259,8 @@ def _read_and_batch_from_files(
         # into sentences, and finally expand to a global batch. It could prove
         # the global batch divisble for distribution strategy.
         int(batch_size // num_replicas // max_length * num_replicas),
-        ([max_length], [max_length]), drop_remainder=True)
+        ([max_length], [max_length]),
+        drop_remainder=True)
   else:
     # Group and batch such that each batch has examples of similar length.
     # TODO(xunkai): _batch_examples might need to do something special for
@@ -291,10 +299,15 @@ def train_input_fn(params, ctx=None):
   if params["use_synthetic_data"]:
     return _generate_synthetic_data(params)
   return _read_and_batch_from_files(
-      file_pattern, params["batch_size"], params["max_length"],
-      params["max_io_parallelism"], shuffle=True,
-      repeat=params["repeat_dataset"], static_batch=params["static_batch"],
-      num_replicas=params["num_gpus"], ctx=ctx)
+      file_pattern,
+      params["batch_size"],
+      params["max_length"],
+      params["max_io_parallelism"],
+      shuffle=True,
+      repeat=params["repeat_dataset"],
+      static_batch=params["static_batch"],
+      num_replicas=params["num_gpus"],
+      ctx=ctx)


 def eval_input_fn(params, ctx=None):
@@ -303,9 +316,14 @@ def eval_input_fn(params, ctx=None):
   if params["use_synthetic_data"]:
     return _generate_synthetic_data(params)
   return _read_and_batch_from_files(
-      file_pattern, params["batch_size"], params["max_length"],
-      params["max_io_parallelism"], shuffle=False, repeat=1,
-      static_batch=params["static_batch"], num_replicas=params["num_gpus"],
+      file_pattern,
+      params["batch_size"],
+      params["max_length"],
+      params["max_io_parallelism"],
+      shuffle=False,
+      repeat=1,
+      static_batch=params["static_batch"],
+      num_replicas=params["num_gpus"],
       ctx=ctx)
......
@@ -60,6 +60,7 @@ class EmbeddingSharedWeights(tf.keras.layers.Layer):
     Args:
       inputs: An int64 tensor with shape [batch_size, length]
       mode: string, a valid value is one of "embedding" and "linear".
+
     Returns:
       outputs: (1) If mode == "embedding", output embedding tensor, float32 with
         shape [batch_size, length, embedding_size]; (2) mode == "linear", output
@@ -82,7 +83,7 @@ class EmbeddingSharedWeights(tf.keras.layers.Layer):
       mask = tf.cast(tf.not_equal(inputs, 0), embeddings.dtype)
       embeddings *= tf.expand_dims(mask, -1)
       # Scale embedding by the sqrt of the hidden size
-      embeddings *= self.hidden_size ** 0.5
+      embeddings *= self.hidden_size**0.5

       return embeddings
@@ -91,6 +92,7 @@ class EmbeddingSharedWeights(tf.keras.layers.Layer):
     Args:
       inputs: A float32 tensor with shape [batch_size, length, hidden_size]
+
     Returns:
       float32 tensor with shape [batch_size, length, vocab_size].
     """
......
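The two modes documented above share a single weight matrix: "embedding" gathers rows of it, while "linear" reuses it (transposed) as the output projection, which is why the logits end with a vocab_size dimension. A minimal sketch of the linear step, assuming shared weights of shape [vocab_size, hidden_size]:

    import tensorflow as tf

    def linear(inputs, shared_weights, hidden_size):
      # inputs: [batch_size, length, hidden_size] -> [batch_size, length, vocab_size]
      batch_size, length = tf.shape(inputs)[0], tf.shape(inputs)[1]
      x = tf.reshape(inputs, [-1, hidden_size])
      logits = tf.matmul(x, shared_weights, transpose_b=True)
      vocab_size = tf.shape(shared_weights)[0]
      return tf.reshape(logits, [batch_size, length, vocab_size])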
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
+
 # pylint: disable=g-bad-import-order
 from absl import flags
 import tensorflow as tf
@@ -66,28 +67,34 @@ def define_transformer_flags():
       tf_gpu_thread_mode=True,
       datasets_num_private_threads=True,
       enable_xla=True,
-      fp16_implementation=True
-  )
+      fp16_implementation=True)

   flags_core.define_benchmark()
   flags_core.define_device(tpu=True)

   flags.DEFINE_integer(
-      name='train_steps', short_name='ts', default=300000,
+      name='train_steps',
+      short_name='ts',
+      default=300000,
       help=flags_core.help_wrap('The number of steps used to train.'))
   flags.DEFINE_integer(
-      name='steps_between_evals', short_name='sbe', default=5000,
+      name='steps_between_evals',
+      short_name='sbe',
+      default=5000,
       help=flags_core.help_wrap(
           'The Number of training steps to run between evaluations. This is '
          'used if --train_steps is defined.'))
   flags.DEFINE_boolean(
-      name='enable_time_history', default=True,
+      name='enable_time_history',
+      default=True,
       help='Whether to enable TimeHistory callback.')
   flags.DEFINE_boolean(
-      name='enable_tensorboard', default=False,
+      name='enable_tensorboard',
+      default=False,
       help='Whether to enable Tensorboard callback.')
   flags.DEFINE_boolean(
-      name='enable_metrics_in_training', default=False,
+      name='enable_metrics_in_training',
+      default=False,
       help='Whether to enable metrics during training.')
   flags.DEFINE_boolean(
       name='enable_mlir_bridge',
@@ -100,7 +107,9 @@ def define_transformer_flags():
   # Add transformer-specific flags
   flags.DEFINE_enum(
-      name='param_set', short_name='mp', default='big',
+      name='param_set',
+      short_name='mp',
+      default='big',
       enum_values=PARAMS_MAP.keys(),
       help=flags_core.help_wrap(
           'Parameter set to use when creating and training the model. The '
@@ -111,7 +120,9 @@ def define_transformer_flags():
          'complete list of parameters, please see model/model_params.py.'))
   flags.DEFINE_bool(
-      name='static_batch', short_name='sb', default=False,
+      name='static_batch',
+      short_name='sb',
+      default=False,
       help=flags_core.help_wrap(
           'Whether the batches in the dataset should have static shapes. In '
          'general, this setting should be False. Dynamic shapes allow the '
@@ -120,7 +131,9 @@ def define_transformer_flags():
          'must be static (e.g. running on TPU), this setting will be ignored '
          'and static batching will always be used.'))
   flags.DEFINE_integer(
-      name='max_length', short_name='ml', default=256,
+      name='max_length',
+      short_name='ml',
+      default=256,
       help=flags_core.help_wrap(
           'Max sentence length for Transformer. Default is 256. Note: Usually '
          'it is more effective to use a smaller max length if static_batch is '
@@ -128,30 +141,39 @@ def define_transformer_flags():
   # Flags for training with steps (may be used for debugging)
   flags.DEFINE_integer(
-      name='validation_steps', short_name='vs', default=64,
+      name='validation_steps',
+      short_name='vs',
+      default=64,
       help=flags_core.help_wrap('The number of steps used in validation.'))

   # BLEU score computation
   flags.DEFINE_string(
-      name='bleu_source', short_name='bls', default=None,
+      name='bleu_source',
+      short_name='bls',
+      default=None,
       help=flags_core.help_wrap(
           'Path to source file containing text translate when calculating the '
          'official BLEU score. Both --bleu_source and --bleu_ref must be set. '
       ))
   flags.DEFINE_string(
-      name='bleu_ref', short_name='blr', default=None,
+      name='bleu_ref',
+      short_name='blr',
+      default=None,
       help=flags_core.help_wrap(
           'Path to source file containing text translate when calculating the '
          'official BLEU score. Both --bleu_source and --bleu_ref must be set. '
       ))
   flags.DEFINE_string(
-      name='vocab_file', short_name='vf', default=None,
+      name='vocab_file',
+      short_name='vf',
+      default=None,
       help=flags_core.help_wrap(
           'Path to subtoken vocabulary file. If data_download.py was used to '
          'download and encode the training data, look in the data_dir to find '
          'the vocab file.'))
   flags.DEFINE_string(
-      name='mode', default='train',
+      name='mode',
+      default='train',
       help=flags_core.help_wrap('mode: train, eval, or predict'))
   flags.DEFINE_bool(
       name='use_ctl',
@@ -188,9 +210,10 @@ def define_transformer_flags():
           'Whether to do checkpointing during training. When running under '
          'benchmark harness, we will avoid checkpointing.'))

-  flags_core.set_defaults(data_dir='/tmp/translate_ende',
-                          model_dir='/tmp/transformer_model',
-                          batch_size=None)
+  flags_core.set_defaults(
+      data_dir='/tmp/translate_ende',
+      model_dir='/tmp/transformer_model',
+      batch_size=None)

   # pylint: disable=unused-variable
   @flags.multi_flags_validator(
@@ -203,11 +226,12 @@ def define_transformer_flags():
   @flags.multi_flags_validator(
       ['bleu_source', 'bleu_ref', 'vocab_file'],
       message='--vocab_file must be defined if --bleu_source and --bleu_ref '
      'are defined.')
   def _check_bleu_vocab_file(flags_dict):
     if flags_dict['bleu_source'] and flags_dict['bleu_ref']:
       return flags_dict['vocab_file'] is not None
     return True
+
   # pylint: enable=unused-variable
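Putting the flags above together, a representative training run looks like the following (the entry-point module is not part of this diff, so its name here is hypothetical):

    python transformer_main.py --mode=train --param_set=big \
        --train_steps=300000 --steps_between_evals=5000 \
        --data_dir=/tmp/translate_ende --model_dir=/tmp/transformer_model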
@@ -256,5 +280,5 @@ def update_stats(history, stats, callbacks):
     if len(timestamp_log) > 1:
       stats['avg_exp_per_second'] = (
           callback.batch_size * callback.log_steps *
-          (len(callback.timestamp_log)-1) /
+          (len(callback.timestamp_log) - 1) /
           (timestamp_log[-1].timestamp - timestamp_log[0].timestamp))