Commit bf748370 authored by Nimit Nigania

Merge remote-tracking branch 'upstream/master'

parents 7c732da7 0d2c2e01
......@@ -21,6 +21,7 @@ import os
import time
from absl import flags
import tensorflow as tf
from official.transformer.v2 import misc
from official.transformer.v2 import transformer_main as transformer_main
......@@ -30,6 +31,7 @@ from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
TRANSFORMER_EN2DE_DATA_DIR_NAME = 'wmt32k-en2de-official'
EN2DE_2014_BLEU_DATA_DIR_NAME = 'newstest2014'
FLAGS = flags.FLAGS
TMP_DIR = os.getenv('TMPDIR')
class TransformerBenchmark(PerfZeroBenchmark):
......@@ -56,6 +58,11 @@ class TransformerBenchmark(PerfZeroBenchmark):
EN2DE_2014_BLEU_DATA_DIR_NAME,
'newstest2014.de')
default_flags['train_steps'] = 200
default_flags['log_steps'] = 10
default_flags['data_dir'] = self.train_data_dir
default_flags['vocab_file'] = self.vocab_file
super(TransformerBenchmark, self).__init__(
output_dir=output_dir,
default_flags=default_flags,
......@@ -280,8 +287,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps,
bleu_min=28,
bleu_max=29)
bleu_min=27.9,
bleu_max=29.2)
def benchmark_8_gpu_static_batch(self):
"""Benchmark 8 gpu.
......@@ -305,12 +312,19 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps,
bleu_min=28,
bleu_max=29)
bleu_max=29.2)
def benchmark_8_gpu_fp16(self):
"""Benchmark 8 gpu with dynamic batch and fp16.
Should converge to 28.4 BLEU (uncased). This has not been verified yet.
Over 6 runs with eval every 20K steps, the average highest value was 28.247
(BLEU uncased); 28.424 was the highest and 28.09 the lowest. Each value is
the highest seen during a run, occurring at a median of iteration 11. While
this could be read as worse than FP32, when comparing the first iteration at
which 28 is passed, FP16 performs equal to and possibly better than FP32.
Although not part of the initial test runs, the highest value recorded with
the arguments below was 28.9 at iteration 12. Iterations are not epochs; an
iteration is the number of steps between evals.
"""
self._setup()
FLAGS.num_gpus = 8
......@@ -328,7 +342,7 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps,
bleu_min=28,
bleu_max=29)
bleu_max=29.2)
def benchmark_8_gpu_static_batch_fp16(self):
"""Benchmark 8 gpu with static batch and fp16.
......@@ -353,7 +367,7 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps,
bleu_min=28,
bleu_max=29)
bleu_max=29.2)
def benchmark_xla_8_gpu_static_batch_fp16(self):
"""Benchmark 8 gpu with static batch, XLA, and FP16.
......@@ -380,7 +394,7 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps,
bleu_min=28,
bleu_max=29)
bleu_max=29.2)
class TransformerKerasBenchmark(TransformerBenchmark):
......@@ -611,19 +625,9 @@ class TransformerKerasBenchmark(TransformerBenchmark):
class TransformerBaseKerasBenchmarkReal(TransformerKerasBenchmark):
"""Transformer based version real data benchmark tests."""
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
train_data_dir = os.path.join(root_data_dir,
TRANSFORMER_EN2DE_DATA_DIR_NAME)
vocab_file = os.path.join(root_data_dir,
TRANSFORMER_EN2DE_DATA_DIR_NAME,
'vocab.ende.32768')
def __init__(self, output_dir=TMP_DIR, root_data_dir=None, **kwargs):
def_flags = {}
def_flags['param_set'] = 'base'
def_flags['vocab_file'] = vocab_file
def_flags['data_dir'] = train_data_dir
def_flags['train_steps'] = 200
def_flags['log_steps'] = 10
super(TransformerBaseKerasBenchmarkReal, self).__init__(
output_dir=output_dir, default_flags=def_flags,
......@@ -633,20 +637,14 @@ class TransformerBaseKerasBenchmarkReal(TransformerKerasBenchmark):
class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
"""Transformer based version real data benchmark tests."""
def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
train_data_dir = os.path.join(root_data_dir,
TRANSFORMER_EN2DE_DATA_DIR_NAME)
vocab_file = os.path.join(root_data_dir,
TRANSFORMER_EN2DE_DATA_DIR_NAME,
'vocab.ende.32768')
def __init__(self, output_dir=TMP_DIR, root_data_dir=None, **kwargs):
def_flags = {}
def_flags['param_set'] = 'big'
def_flags['vocab_file'] = vocab_file
def_flags['data_dir'] = train_data_dir
def_flags['train_steps'] = 200
def_flags['log_steps'] = 10
super(TransformerBigKerasBenchmarkReal, self).__init__(
output_dir=output_dir, default_flags=def_flags,
root_data_dir=root_data_dir, batch_per_gpu=3072)
if __name__ == '__main__':
tf.test.main()
......@@ -27,12 +27,16 @@ import tempfile
from absl import app as absl_app # pylint: disable=unused-import
from absl import flags
from absl import logging
import tensorflow as tf
from tensorflow.python.util import object_identity
# pylint: disable=g-bad-import-order
from official.transformer import compute_bleu
from official.transformer.utils import tokenizer
from official.transformer.v2 import data_pipeline
from official.transformer.v2 import metrics
from official.transformer.v2 import misc
from official.transformer.v2 import optimizer
from official.transformer.v2 import transformer
......@@ -48,18 +52,40 @@ BLEU_DIR = "bleu"
_SINGLE_SAMPLE = 1
def translate_and_compute_bleu(model, subtokenizer, bleu_source, bleu_ref):
"""Translate file and report the cased and uncased bleu scores."""
def translate_and_compute_bleu(model,
params,
subtokenizer,
bleu_source,
bleu_ref,
distribution_strategy=None):
"""Translate file and report the cased and uncased bleu scores.
Args:
model: A Keras model, used to generate the translations.
params: A dictionary, containing the translation related parameters.
subtokenizer: A subtokenizer object, used for encoding and decoding source
and translated lines.
bleu_source: A file containing source sentences for translation.
bleu_ref: A file containing the reference for the translated sentences.
distribution_strategy: A platform distribution strategy, used for TPU based
translation.
Returns:
uncased_score: A float, the case insensitive BLEU score.
cased_score: A float, the case sensitive BLEU score.
"""
# Create temporary file to store translation.
tmp = tempfile.NamedTemporaryFile(delete=False)
tmp_filename = tmp.name
translate.translate_file(
model,
params,
subtokenizer,
bleu_source,
output_file=tmp_filename,
print_all_translations=False)
print_all_translations=False,
distribution_strategy=distribution_strategy)
# Compute uncased and cased bleu scores.
uncased_score = compute_bleu.bleu_wrapper(bleu_ref, tmp_filename, False)
......@@ -68,15 +94,34 @@ def translate_and_compute_bleu(model, subtokenizer, bleu_source, bleu_ref):
return uncased_score, cased_score
def evaluate_and_log_bleu(model, bleu_source, bleu_ref, vocab_file):
"""Calculate and record the BLEU score."""
def evaluate_and_log_bleu(model,
params,
bleu_source,
bleu_ref,
vocab_file,
distribution_strategy=None):
"""Calculate and record the BLEU score.
Args:
model: A Keras model, used to generate the translations.
params: A dictionary, containing the translation related parameters.
bleu_source: A file containing source sentences for translation.
bleu_ref: A file containing the reference for the translated sentences.
vocab_file: A file containing the vocabulary for translation.
distribution_strategy: A platform distribution strategy, used for TPU based
translation.
Returns:
uncased_score: A float, the case insensitive BLEU score.
cased_score: A float, the case sensitive BLEU score.
"""
subtokenizer = tokenizer.Subtokenizer(vocab_file)
uncased_score, cased_score = translate_and_compute_bleu(
model, subtokenizer, bleu_source, bleu_ref)
model, params, subtokenizer, bleu_source, bleu_ref, distribution_strategy)
tf.compat.v1.logging.info("Bleu score (uncased): %s", uncased_score)
tf.compat.v1.logging.info("Bleu score (cased): %s", cased_score)
logging.info("Bleu score (uncased): %s", uncased_score)
logging.info("Bleu score (cased): %s", cased_score)
return uncased_score, cased_score
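A minimal usage sketch of the expanded signature; the file paths and the pre-built model are assumptions, not part of this change:

```python
# Hedged sketch: scoring a trained model (all concrete values are assumptions).
model = transformer.create_model(params, is_train=False)
uncased, cased = evaluate_and_log_bleu(
    model,
    params,
    bleu_source="/tmp/newstest2014.en",   # assumed source-sentence file
    bleu_ref="/tmp/newstest2014.de",      # assumed reference file
    vocab_file="/tmp/vocab.ende.32768",   # assumed subword vocabulary
    distribution_strategy=None)           # pass a TPUStrategy for TPU inference
```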
......@@ -88,30 +133,27 @@ class TransformerTask(object):
Args:
flags_obj: Object containing parsed flag values, i.e., FLAGS.
Raises:
ValueError: if not using static batch for input data on TPU.
"""
self.flags_obj = flags_obj
self.predict_model = None
# Add flag-defined parameters to params object
num_gpus = flags_core.get_num_gpus(flags_obj)
self.distribution_strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_core.get_num_gpus(flags_obj))
print("Running transformer with num_gpus =", num_gpus)
if self.distribution_strategy:
print("For training, using distribution strategy: ",
self.distribution_strategy)
else:
print("Not using any distribution strategy.")
self.params = params = misc.get_model_params(flags_obj.param_set, num_gpus)
params["num_gpus"] = num_gpus
params["use_ctl"] = flags_obj.use_ctl
params["is_tpu_pod"] = flags_obj.is_tpu_pod
params["data_dir"] = flags_obj.data_dir
params["model_dir"] = flags_obj.model_dir
params["static_batch"] = flags_obj.static_batch
params["max_length"] = flags_obj.max_length
params["decode_batch_size"] = flags_obj.decode_batch_size
params["decode_max_length"] = flags_obj.decode_max_length
params["padded_decode"] = flags_obj.padded_decode
params["num_parallel_calls"] = (
flags_obj.num_parallel_calls or tf.data.experimental.AUTOTUNE)
......@@ -130,33 +172,114 @@ class TransformerTask(object):
"infer_float32_vars")
tf.keras.mixed_precision.experimental.set_policy(policy)
self.distribution_strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=num_gpus,
tpu_address=flags_obj.tpu or "")
if self.use_tpu:
params["num_replicas"] = self.distribution_strategy.num_replicas_in_sync
if not params["static_batch"]:
raise ValueError("TPU requires static batch for input data.")
else:
print("Running transformer with num_gpus =", num_gpus)
if self.distribution_strategy:
print("For training, using distribution strategy: ",
self.distribution_strategy)
else:
print("Not using any distribution strategy.")
@property
def use_tpu(self):
if self.distribution_strategy:
return isinstance(self.distribution_strategy,
tf.distribute.experimental.TPUStrategy)
return False
def train(self):
"""Trains the model."""
params, flags_obj, is_train = self.params, self.flags_obj, True
params = self.params
flags_obj = self.flags_obj
# Sets config options.
keras_utils.set_session_config(
enable_xla=flags_obj.enable_xla)
_ensure_dir(flags_obj.model_dir)
if self.distribution_strategy:
with self.distribution_strategy.scope():
model = transformer.create_model(params, is_train)
opt = self._create_optimizer()
model.compile(opt)
else:
model = transformer.create_model(params, is_train)
with distribution_utils.get_strategy_scope(self.distribution_strategy):
model = transformer.create_model(params, is_train=True)
opt = self._create_optimizer()
model.compile(opt)
if params["use_ctl"]:
train_loss_metric = tf.keras.metrics.Mean(
"training_loss", dtype=tf.float32)
else:
model.compile(opt)
model.summary()
train_ds = data_pipeline.train_input_fn(params)
map_data_fn = data_pipeline.map_data_for_transformer_fn
train_ds = train_ds.map(map_data_fn,
num_parallel_calls=params["num_parallel_calls"])
if self.use_tpu:
# Different from experimental_distribute_dataset,
# experimental_distribute_datasets_from_function requires
# per-replica/local batch size.
params["batch_size"] /= self.distribution_strategy.num_replicas_in_sync
train_ds = (
self.distribution_strategy
.experimental_distribute_datasets_from_function(
lambda ctx: data_pipeline.train_input_fn(params)))
else:
train_ds = data_pipeline.train_input_fn(params)
map_data_fn = data_pipeline.map_data_for_transformer_fn
train_ds = train_ds.map(
map_data_fn, num_parallel_calls=params["num_parallel_calls"])
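The comment above is the crux of the TPU input path; here is a self-contained sketch of the batch-size arithmetic, with illustrative numbers that are assumptions:

```python
# Illustrative numbers only (assumptions, not values from this change).
global_batch_size = 4096
num_replicas = 8  # e.g. strategy.num_replicas_in_sync on a TPU v3-8

# experimental_distribute_dataset: batch with the GLOBAL size; the strategy
# splits each batch across replicas.
# experimental_distribute_datasets_from_function: the input_fn runs once per
# replica, so it must batch with the PER-REPLICA (local) size.
per_replica_batch_size = global_batch_size // num_replicas
assert per_replica_batch_size == 512
```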
if params["use_ctl"]:
train_ds_iterator = iter(train_ds)
callbacks = self._create_callbacks(flags_obj.model_dir, 0, params)
# TODO(b/139418525): Refactor the custom training loop logic.
@tf.function
def train_steps(iterator, steps):
"""Training steps function for TPU runs.
Args:
iterator: The input iterator of the training dataset.
steps: An integer, the number of training steps.
Returns:
A float, the loss value.
"""
def _step_fn(inputs):
"""Per-replica step function."""
inputs, targets = inputs
with tf.GradientTape() as tape:
logits = model([inputs, targets], training=True)
loss = metrics.transformer_loss(logits, targets,
params["label_smoothing"],
params["vocab_size"])
# Scales the loss, which results in using the average loss across all
# of the replicas for backprop.
scaled_loss = loss / self.distribution_strategy.num_replicas_in_sync
# De-dupes variables due to keras tracking issues.
tvars = list(
object_identity.ObjectIdentitySet(model.trainable_variables))
grads = tape.gradient(scaled_loss, tvars)
opt.apply_gradients(zip(grads, tvars))
# For reporting, the metric takes the mean of losses.
train_loss_metric.update_state(loss)
for _ in tf.range(steps):
train_loss_metric.reset_states()
self.distribution_strategy.experimental_run_v2(
_step_fn, args=(next(iterator),))
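The `scaled_loss` line can be checked with plain arithmetic; this sketch, using made-up per-replica losses, shows that summing gradients of `loss / num_replicas` across replicas recovers the gradient of the global mean loss:

```python
# Made-up per-replica losses (assumptions) to illustrate the scaling.
num_replicas = 4
per_replica_losses = [2.0, 2.4, 1.6, 2.0]

# Cross-replica aggregation SUMS per-replica gradients, so dividing each loss
# by num_replicas before differentiation is equivalent to using the mean loss.
scaled_sum = sum(loss / num_replicas for loss in per_replica_losses)
assert abs(scaled_sum - sum(per_replica_losses) / num_replicas) < 1e-9
```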
if self.use_tpu:
checkpoint = tf.train.Checkpoint(model=model, optimizer=opt)
latest_checkpoint = tf.train.latest_checkpoint(flags_obj.model_dir)
if latest_checkpoint:
checkpoint.restore(latest_checkpoint)
logging.info("Loaded checkpoint %s", latest_checkpoint)
if flags_obj.train_steps < flags_obj.steps_between_evals:
flags_obj.steps_between_evals = flags_obj.train_steps
iterations = flags_obj.train_steps // flags_obj.steps_between_evals
......@@ -165,28 +288,54 @@ class TransformerTask(object):
cased_score_history, uncased_score_history = [], []
for i in range(1, iterations + 1):
print("Start train iteration:{}/{}".format(i, iterations))
history = model.fit(
train_ds,
initial_epoch=i-1,
epochs=i,
steps_per_epoch=flags_obj.steps_between_evals,
callbacks=callbacks,
# If TimeHistory is enabled, the progress bar would be messy. Increase the
# verbose level to get rid of it.
verbose=(2 if flags_obj.enable_time_history else 1))
history = None
if params["use_ctl"]:
if not self.use_tpu:
raise NotImplementedError(
"Custom training loop on GPUs is not implemented.")
train_steps_per_eval = tf.convert_to_tensor(
flags_obj.steps_between_evals, dtype=tf.int32)
# Runs training steps.
train_steps(train_ds_iterator, train_steps_per_eval)
train_loss = train_loss_metric.result().numpy().astype(float)
logging.info("Train Step: %d/%d / loss = %s",
i * flags_obj.steps_between_evals, flags_obj.train_steps,
train_loss)
checkpoint_name = checkpoint.save(
os.path.join(
flags_obj.model_dir,
"ctl_step_{}.ckpt".format(i * flags_obj.steps_between_evals)))
logging.info("Saved checkpoint to %s", checkpoint_name)
else:
if self.use_tpu:
raise NotImplementedError(
"Keras model.fit on TPUs is not implemented.")
history = model.fit(
train_ds,
initial_epoch=i - 1,
epochs=i,
steps_per_epoch=flags_obj.steps_between_evals,
callbacks=callbacks,
# If TimeHistory is enabled, the progress bar would be messy. Increase
# the verbose level to get rid of it.
verbose=(2 if flags_obj.enable_time_history else 1))
logging.info("Train history: {}".format(history.history))
print("End train iteration:{}/{} global step:{}".format(
i,
iterations,
i*flags_obj.steps_between_evals))
tf.compat.v1.logging.info("Train history: {}".format(history.history))
stats = misc.build_stats(history, callbacks)
if (flags_obj.bleu_source and flags_obj.bleu_ref):
uncased_score, cased_score = self.eval()
cased_score_history.append([i, cased_score])
uncased_score_history.append([i, uncased_score])
stats = misc.build_stats(history, callbacks)
stats = ({
"loss": train_loss
} if history is None else misc.build_stats(history, callbacks))
if uncased_score and cased_score:
stats["bleu_uncased"] = uncased_score
stats["bleu_cased"] = cased_score
......@@ -202,17 +351,18 @@ class TransformerTask(object):
self.predict_model,
tf.train.latest_checkpoint(self.flags_obj.model_dir))
self.predict_model.summary()
return evaluate_and_log_bleu(self.predict_model,
self.flags_obj.bleu_source,
self.flags_obj.bleu_ref,
self.flags_obj.vocab_file)
return evaluate_and_log_bleu(
self.predict_model, self.params, self.flags_obj.bleu_source,
self.flags_obj.bleu_ref, self.flags_obj.vocab_file,
self.distribution_strategy if self.use_tpu else None)
def predict(self):
"""Predicts result from the model."""
params, flags_obj, is_train = self.params, self.flags_obj, False
params = self.params
flags_obj = self.flags_obj
with tf.name_scope("model"):
model = transformer.create_model(params, is_train)
model = transformer.create_model(params, is_train=False)
self._load_weights_if_possible(
model, tf.train.latest_checkpoint(self.flags_obj.model_dir))
model.summary()
......@@ -242,16 +392,28 @@ class TransformerTask(object):
def _load_weights_if_possible(self, model, init_weight_path=None):
"""Loads model weights when it is provided."""
if init_weight_path:
tf.compat.v1.logging.info("Load weights: {}".format(init_weight_path))
model.load_weights(init_weight_path)
logging.info("Load weights: {}".format(init_weight_path))
# TODO(b/139414977): Having the same variable restoring method for both
# TPU and GPU.
if self.use_tpu:
checkpoint = tf.train.Checkpoint(
model=model, optimizer=self._create_optimizer())
checkpoint.restore(init_weight_path)
else:
model.load_weights(init_weight_path)
else:
print("Weights not loaded from path:{}".format(init_weight_path))
def _create_optimizer(self):
"""Creates optimizer."""
params = self.params
# TODO(b/139414679): Explore the difference between using
# LearningRateSchedule and callback for GPU runs, and try to merge them.
lr_schedule = optimizer.LearningRateSchedule(
params["learning_rate"], params["hidden_size"],
params["learning_rate_warmup_steps"])
opt = tf.keras.optimizers.Adam(
params["learning_rate"],
lr_schedule if self.use_tpu else params["learning_rate"],
params["optimizer_adam_beta1"],
params["optimizer_adam_beta2"],
epsilon=params["optimizer_adam_epsilon"])
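To make the branch concrete: on TPU the Adam learning rate is the schedule object itself, evaluated per step, while GPU runs keep a scalar rate that callbacks adjust. A sketch with assumed hyperparameters:

```python
# Assumed hyperparameters, in the positional order used above:
# (learning_rate, hidden_size, learning_rate_warmup_steps).
lr_schedule = optimizer.LearningRateSchedule(2.0, 512, 16000)
tpu_opt = tf.keras.optimizers.Adam(lr_schedule)  # schedule evaluated per step
gpu_opt = tf.keras.optimizers.Adam(2.0)          # scalar; adjusted by callback
```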
......@@ -264,25 +426,34 @@ class TransformerTask(object):
def _ensure_dir(log_dir):
"""Makes log dir if not existed."""
if not os.path.exists(log_dir):
os.makedirs(log_dir)
if not tf.io.gfile.exists(log_dir):
tf.io.gfile.makedirs(log_dir)
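The switch from `os` to `tf.io.gfile` matters because the model dir can now live on a remote filesystem, which TPU workers require. A sketch; the bucket name is an assumption:

```python
# Both calls go through the same code path; the GCS bucket is an assumption.
_ensure_dir("/tmp/transformer_model")       # local filesystem
_ensure_dir("gs://my-bucket/transformer")   # remote (GCS) filesystem
```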
def main(_):
flags_obj = flags.FLAGS
with logger.benchmark_context(flags_obj):
task = TransformerTask(flags_obj)
if flags_obj.mode == "train":
task.train()
elif flags_obj.mode == "predict":
task.predict()
elif flags_obj.mode == "eval":
task.eval()
def _run_task(task):
if flags_obj.mode == "train":
task.train()
elif flags_obj.mode == "predict":
task.predict()
elif flags_obj.mode == "eval":
task.eval()
else:
raise ValueError("Invalid mode {}".format(flags_obj.mode))
if flags_obj.distribution_strategy != "tpu":
_run_task(task)
else:
raise ValueError("Invalid mode {}".format(flags_obj.mode))
primary_cpu_task = "/job:worker" if flags_obj.use_tpu_2vm_config else ""
with tf.device(primary_cpu_task):
_run_task(task)
if __name__ == "__main__":
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
logging.set_verbosity(logging.INFO)
misc.define_transformer_flags()
absl_app.run(main)
......@@ -30,7 +30,7 @@ from official.transformer.v2 import misc
from official.transformer.v2 import transformer_main as tm
from official.utils.misc import keras_utils
from tensorflow.python.eager import context # pylint: disable=ungrouped-imports
from tensorflow.python.eager import context # pylint: disable=ungrouped-imports
FLAGS = flags.FLAGS
FIXED_TIMESTAMP = 'my_time_stamp'
......@@ -80,11 +80,19 @@ class TransformerTaskTest(tf.test.TestCase):
self.assertTrue(os.path.exists(filepath))
def test_train_no_dist_strat(self):
if context.num_gpus() >= 2:
self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
t = tm.TransformerTask(FLAGS)
t.train()
def test_train_static_batch(self):
if context.num_gpus() >= 2:
self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
FLAGS.distribution_strategy = 'one_device'
if tf.test.is_built_with_cuda():
FLAGS.num_gpus = 1
else:
FLAGS.num_gpus = 0
FLAGS.static_batch = True
t = tm.TransformerTask(FLAGS)
t.train()
......@@ -97,6 +105,7 @@ class TransformerTaskTest(tf.test.TestCase):
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_train_fp16(self):
FLAGS.distribution_strategy = 'one_device'
FLAGS.dtype = 'fp16'
t = tm.TransformerTask(FLAGS)
t.train()
......@@ -105,8 +114,8 @@ class TransformerTaskTest(tf.test.TestCase):
def test_train_2_gpu(self):
if context.num_gpus() < 2:
self.skipTest(
'{} GPUs are not available for this test. {} GPUs are available'.
format(2, context.num_gpus()))
'{} GPUs are not available for this test. {} GPUs are available'
.format(2, context.num_gpus()))
FLAGS.distribution_strategy = 'mirrored'
FLAGS.num_gpus = 2
FLAGS.param_set = 'base'
......@@ -117,8 +126,8 @@ class TransformerTaskTest(tf.test.TestCase):
def test_train_2_gpu_fp16(self):
if context.num_gpus() < 2:
self.skipTest(
'{} GPUs are not available for this test. {} GPUs are available'.
format(2, context.num_gpus()))
'{} GPUs are not available for this test. {} GPUs are available'
.format(2, context.num_gpus()))
FLAGS.distribution_strategy = 'mirrored'
FLAGS.num_gpus = 2
FLAGS.param_set = 'base'
......@@ -153,16 +162,22 @@ class TransformerTaskTest(tf.test.TestCase):
FLAGS(update_flags)
def test_predict(self):
if context.num_gpus() >= 2:
self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
self._prepare_files_and_flags()
t = tm.TransformerTask(FLAGS)
t.predict()
def test_predict_fp16(self):
if context.num_gpus() >= 2:
self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
self._prepare_files_and_flags('--dtype=fp16')
t = tm.TransformerTask(FLAGS)
t.predict()
def test_eval(self):
if context.num_gpus() >= 2:
self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
self._prepare_files_and_flags()
t = tm.TransformerTask(FLAGS)
t.eval()
......
......@@ -18,11 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import values
from official.transformer.utils import tokenizer
_DECODE_BATCH_SIZE = 32
_EXTRA_DECODE_LENGTH = 100
_BEAM_SIZE = 4
_ALPHA = 0.6
......@@ -68,23 +69,31 @@ def _trim_and_decode(ids, subtokenizer):
return subtokenizer.decode(ids)
def translate_file(
model, subtokenizer, input_file, output_file=None,
print_all_translations=True):
def translate_file(model,
params,
subtokenizer,
input_file,
output_file=None,
print_all_translations=True,
distribution_strategy=None):
"""Translate lines in file, and save to output file if specified.
Args:
model: Keras model used to generate the translations.
subtokenizer: Subtokenizer object for encoding and decoding source and
translated lines.
input_file: file containing lines to translate
output_file: file that stores the generated translations.
print_all_translations: If true, all translations are printed to stdout.
model: A Keras model, used to generate the translations.
params: A dictionary, containing the translation related parameters.
subtokenizer: A subtokenizer object, used for encoding and decoding source
and translated lines.
input_file: A file containing lines to translate.
output_file: A file that stores the generated translations.
print_all_translations: A bool. If true, all translations are printed to
stdout.
distribution_strategy: A distribution strategy, used to perform inference
directly with tf.function instead of Keras model.predict().
Raises:
ValueError: if output file is invalid.
"""
batch_size = _DECODE_BATCH_SIZE
batch_size = params["decode_batch_size"]
# Read and sort inputs by length. Keep dictionary (original index-->new index
# in sorted list) to write translations in the original order.
......@@ -101,24 +110,59 @@ def translate_file(
if j + i * batch_size < total_samples
]
lines = [_encode_and_add_eos(l, subtokenizer) for l in lines]
if distribution_strategy:
for j in range(batch_size - len(lines)):
lines.append([tokenizer.EOS_ID])
batch = tf.keras.preprocessing.sequence.pad_sequences(
lines, dtype="int64", padding="post")
lines,
maxlen=params["decode_max_length"],
dtype="int32",
padding="post")
tf.compat.v1.logging.info("Decoding batch %d out of %d.", i,
num_decode_batches)
yield batch
@tf.function
def predict_step(inputs):
"""Decoding step function for TPU runs."""
def _step_fn(inputs):
"""Per replica step function."""
val_outputs, _ = model([inputs], training=False)
return val_outputs
return distribution_strategy.experimental_run_v2(_step_fn, args=(inputs,))
translations = []
if distribution_strategy:
num_replicas = distribution_strategy.num_replicas_in_sync
local_batch_size = params["decode_batch_size"] // num_replicas
for i, text in enumerate(input_generator()):
val_outputs, _ = model.predict(text)
if distribution_strategy:
text = np.reshape(text, [num_replicas, local_batch_size, -1])
text = [
tf.convert_to_tensor(per_replica_text) for per_replica_text in text
]
# pylint: disable=protected-access
text = values.PerReplica(distribution_strategy.extended._device_map, text)
# pylint: enable=protected-access
val_outputs = distribution_strategy.experimental_local_results(
predict_step(text))
val_outputs = np.reshape(
[val_output.numpy() for val_output in val_outputs],
[params["decode_batch_size"], -1])
else:
val_outputs, _ = model.predict(text)
length = len(val_outputs)
for j in range(length):
translation = _trim_and_decode(val_outputs[j], subtokenizer)
translations.append(translation)
if print_all_translations:
tf.compat.v1.logging.info(
"Translating:\n\tInput: %s\n\tOutput: %s" %
(sorted_inputs[j + i * batch_size], translation))
if j + i * batch_size < total_samples:
translation = _trim_and_decode(val_outputs[j], subtokenizer)
translations.append(translation)
if print_all_translations:
tf.compat.v1.logging.info(
"Translating:\n\tInput: %s\n\tOutput: %s" %
(sorted_inputs[j + i * batch_size], translation))
# Write translations in the order they appeared in the original file.
if output_file is not None:
......
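A minimal sketch of calling the updated `translate_file` signature; the model, params, and file paths are assumptions:

```python
# Hedged sketch of the expanded call (all concrete values are assumptions).
subtokenizer = tokenizer.Subtokenizer("/tmp/vocab.ende.32768")
translate_file(
    model,                        # a trained Keras transformer (assumed)
    params,                       # must include "decode_batch_size" and
                                  # "decode_max_length"
    subtokenizer,
    input_file="/tmp/newstest2014.en",
    output_file="/tmp/translations.de",
    print_all_translations=False,
    distribution_strategy=None)   # pass a strategy to use predict_step
```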
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flags related to distributed execution."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
import tensorflow as tf
from official.utils.flags._conventions import help_wrap
def define_distribution(worker_hosts=True, task_index=True):
"""Register distributed execution flags.
Args:
worker_hosts: Create a flag for specifying comma-separated list of workers.
task_index: Create a flag for specifying index of task.
Returns:
A list of flags for core.py to marks as key flags.
"""
key_flags = []
if worker_hosts:
flags.DEFINE_string(
name='worker_hosts', default=None,
help=help_wrap(
'Comma-separated list of worker ip:port pairs for running '
'multi-worker models with DistributionStrategy. The user would '
'start the program on each host with identical value for this '
'flag.'))
if task_index:
flags.DEFINE_integer(
name='task_index', default=-1,
help=help_wrap('If multi-worker training, the task_index of this '
'worker.'))
return key_flags
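A short sketch of registering these flags and parsing them; the argv values are assumptions:

```python
# Sketch: register and parse the new distribution flags (argv is assumed).
from absl import flags

define_distribution(worker_hosts=True, task_index=True)
flags.FLAGS([
    "prog",
    "--worker_hosts=10.0.0.1:2222,10.0.0.2:2222",
    "--task_index=0",
])
print(flags.FLAGS.worker_hosts, flags.FLAGS.task_index)
```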
......@@ -53,8 +53,8 @@ def get_loss_scale(flags_obj, default_for_fp16):
return default_for_fp16
def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
synthetic_data=True, max_train_steps=True, dtype=True,
def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
synthetic_data=True, max_train_steps=False, dtype=True,
all_reduce_alg=True, num_packs=True,
tf_gpu_thread_mode=False,
datasets_num_private_threads=False,
......
......@@ -32,6 +32,7 @@ from official.utils.flags import _base
from official.utils.flags import _benchmark
from official.utils.flags import _conventions
from official.utils.flags import _device
from official.utils.flags import _distribution
from official.utils.flags import _misc
from official.utils.flags import _performance
......@@ -77,6 +78,8 @@ define_benchmark = register_key_flags_in_core(_benchmark.define_benchmark)
define_device = register_key_flags_in_core(_device.define_device)
define_image = register_key_flags_in_core(_misc.define_image)
define_performance = register_key_flags_in_core(_performance.define_performance)
define_distribution = register_key_flags_in_core(
_distribution.define_distribution)
help_wrap = _conventions.help_wrap
......
......@@ -23,7 +23,9 @@ from official.utils.flags import core as flags_core # pylint: disable=g-bad-imp
def define_flags():
flags_core.define_base(num_gpu=False)
flags_core.define_performance(dynamic_loss_scale=True, loss_scale=True)
flags_core.define_performance(
num_parallel_calls=True, inter_op=True, intra_op=True,
dynamic_loss_scale=True, loss_scale=True)
flags_core.define_image()
flags_core.define_benchmark()
......
......@@ -127,10 +127,7 @@ def get_distribution_strategy(distribution_strategy="default",
return None
if distribution_strategy == "tpu":
if not tpu_address:
raise ValueError("`tpu_address` must be specified when using "
"TPUStrategy.")
# When tpu_address is an empty string, we communicate with local TPUs.
# Initialize TPU System.
cluster_resolver = tpu_lib.tpu_initialize(tpu_address)
return tf.distribute.experimental.TPUStrategy(cluster_resolver)
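With the check removed, an empty `tpu_address` now means a locally attached TPU; a sketch of both cases, where the grpc address is an assumption:

```python
# "" connects to a local TPU; the remote grpc address is an assumption.
local_tpu = get_distribution_strategy(
    distribution_strategy="tpu", tpu_address="")
remote_tpu = get_distribution_strategy(
    distribution_strategy="tpu", tpu_address="grpc://10.0.0.2:8470")
```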
......@@ -205,38 +202,64 @@ class SyntheticDataset(object):
"""A dataset that generates synthetic data on each device."""
def __init__(self, dataset, split_by=1):
self._input_data = {}
# dataset.take(1) doesn't have GPU kernel.
with tf.device('device:CPU:0'):
tensor = tf.data.experimental.get_single_element(dataset.take(1))
flat_tensor = tf.nest.flatten(tensor)
variable_data = []
self._initializers = []
initializers = []
for t in flat_tensor:
rebatched_t = tf.split(t, num_or_size_splits=split_by, axis=0)[0]
assert rebatched_t.shape.is_fully_defined(), rebatched_t.shape
v = tf.compat.v1.get_local_variable(self.random_name(),
v = tf.compat.v1.get_local_variable(self._random_name(),
initializer=rebatched_t)
variable_data.append(v)
self._initializers.append(v.initializer)
self._input_data = tf.nest.pack_sequence_as(tensor, variable_data)
initializers.append(v.initializer)
input_data = tf.nest.pack_sequence_as(tensor, variable_data)
self._iterator = SyntheticIterator(input_data, initializers)
def _random_name(self, size=10, chars=string.ascii_uppercase + string.digits):
return ''.join(random.choice(chars) for _ in range(size))
def __iter__(self):
return self._iterator
def make_one_shot_iterator(self):
return self._iterator
def make_initializable_iterator(self):
return self._iterator
class SyntheticIterator(object):
"""A dataset that generates synthetic data on each device."""
def __init__(self, input_data, initializers):
self._input_data = input_data
self._initializers = initializers
def get_next(self):
return self._input_data
def next(self):
return self.__next__()
def __next__(self):
try:
return self.get_next()
except tf.errors.OutOfRangeError:
raise StopIteration
def initialize(self):
if tf.executing_eagerly():
return tf.no_op()
else:
return self._initializers
def random_name(self, size=10, chars=string.ascii_uppercase + string.digits):
return ''.join(random.choice(chars) for _ in range(size))
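A minimal sketch of how the refactored pair behaves; the toy dataset is an assumption, and variable creation is assumed to succeed in the caller's context (the patched strategies below call this inside `self.scope()`):

```python
# Toy dataset (assumption): one cached batch is replayed on every call.
ds = tf.data.Dataset.from_tensors(
    (tf.zeros([8, 32]), tf.zeros([8], dtype=tf.int64)))
synthetic = SyntheticDataset(ds)   # caches the element as local variables
it = iter(synthetic)               # a SyntheticIterator
features, labels = it.get_next()   # same tensors, no input-pipeline cost
```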
def _monkey_patch_dataset_method(strategy):
"""Monkey-patch `strategy`'s `make_dataset_iterator` method."""
def make_dataset_iterator(self, dataset):
def make_dataset(self, dataset):
tf.compat.v1.logging.info('Using pure synthetic data.')
with self.scope():
if self.extended._global_batch_size: # pylint: disable=protected-access
......@@ -244,22 +267,34 @@ def _monkey_patch_dataset_method(strategy):
else:
return SyntheticDataset(dataset)
strategy.org_make_dataset_iterator = strategy.make_dataset_iterator
strategy.make_dataset_iterator = make_dataset_iterator
def make_iterator(self, dataset):
dist_dataset = make_dataset(self, dataset)
return iter(dist_dataset)
strategy.orig_make_dataset_iterator = strategy.make_dataset_iterator
strategy.make_dataset_iterator = make_iterator
strategy.orig_distribute_dataset = strategy.experimental_distribute_dataset
strategy.experimental_distribute_dataset = make_dataset
def _undo_monkey_patch_dataset_method(strategy):
if hasattr(strategy, 'org_make_dataset_iterator'):
strategy.make_dataset_iterator = strategy.org_make_dataset_iterator
if hasattr(strategy, 'orig_make_dataset_iterator'):
strategy.make_dataset_iterator = strategy.orig_make_dataset_iterator
if hasattr(strategy, 'orig_distribute_dataset'):
strategy.experimental_distribute_dataset = strategy.orig_distribute_dataset
def set_up_synthetic_data():
_monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy)
_monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
_monkey_patch_dataset_method(
tf.distribute.experimental.MultiWorkerMirroredStrategy)
# TODO(tobyboyd): Remove when contrib.distribute is all in core.
if hasattr(tf, 'contrib'):
_monkey_patch_dataset_method(tf.contrib.distribute.MirroredStrategy)
_monkey_patch_dataset_method(tf.contrib.distribute.OneDeviceStrategy)
_monkey_patch_dataset_method(
tf.contrib.distribute.CollectiveAllReduceStrategy)
else:
print('Contrib missing: Skip monkey patch tf.contrib.distribute.*')
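In the Keras mains this patch is gated by a flag before any dataset is distributed; a sketch under that assumption:

```python
# Sketch (flags_obj and real_dataset are assumptions); mirrors the mains.
if flags_obj.use_synthetic_data:
    set_up_synthetic_data()        # patch the strategy classes
strategy = tf.distribute.MirroredStrategy()
train_ds = strategy.experimental_distribute_dataset(real_dataset)  # synthetic
```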
......@@ -267,10 +302,14 @@ def set_up_synthetic_data():
def undo_set_up_synthetic_data():
_undo_monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy)
_undo_monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
_undo_monkey_patch_dataset_method(
tf.distribute.experimental.MultiWorkerMirroredStrategy)
# TODO(tobyboyd): Remove when contrib.distribute is all in core.
if hasattr(tf, 'contrib'):
_undo_monkey_patch_dataset_method(tf.contrib.distribute.MirroredStrategy)
_undo_monkey_patch_dataset_method(tf.contrib.distribute.OneDeviceStrategy)
_undo_monkey_patch_dataset_method(
tf.contrib.distribute.CollectiveAllReduceStrategy)
else:
print('Contrib missing: Skip remove monkey patch tf.contrib.distribute.*')
......
......@@ -29,7 +29,7 @@ from absl import flags
from official.utils.flags import core as flags_core
def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1):
def run_synthetic(main, tmp_root, extra_flags=None, synth=True):
"""Performs a minimal run of a model.
This function is intended to test for syntax errors throughout a model. A
......@@ -41,7 +41,6 @@ def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1):
tmp_root: Root path for the temp directory created by the test class.
extra_flags: Additional flags passed by the caller of this function.
synth: Use synthetic data.
max_train: Maximum number of allowed training steps.
"""
extra_flags = [] if extra_flags is None else extra_flags
......@@ -54,9 +53,6 @@ def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1):
if synth:
args.append("--use_synthetic_data")
if max_train is not None:
args.extend(["--max_train_steps", str(max_train)])
try:
flags_core.parse_flags(argv=args)
main(flags.FLAGS)
......
[MESSAGES CONTROL]
disable=R,W,bad-option-value,trailing-newlines
disable=R,W,bad-option-value,trailing-newlines,no-name-in-module
[REPORTS]
# Tells whether to display a full report or only the messages
......
This folder contains the Keras implementation of the ResNet models. For more
information about the models, please refer to this [README file](../README.md).
This folder contains the Keras implementation of the ResNet models. For more
information about the models, please refer to this [README file](../../README.md).
Similar to the [estimator implementation](/official/resnet), the Keras
Similar to the [estimator implementation](../../r1/resnet), the Keras
implementation has code for both CIFAR-10 data and ImageNet data. The CIFAR-10
version uses a ResNet56 model implemented in
[`resnet_cifar_model.py`](./resnet_cifar_model.py), and the ImageNet version
version uses a ResNet56 model implemented in
[`resnet_cifar_model.py`](./resnet_cifar_model.py), and the ImageNet version
uses a ResNet50 model implemented in [`resnet_model.py`](./resnet_model.py).
To use
either dataset, make sure that you have the latest version of TensorFlow
installed and
To use
either dataset, make sure that you have the latest version of TensorFlow
installed and
[add the models folder to your Python path](/official/#running-the-models),
otherwise you may encounter an error like `ImportError: No module named
otherwise you may encounter an error like `ImportError: No module named
official.resnet`.
## CIFAR-10
Download and extract the CIFAR-10 data. You can use the following script:
```bash
python cifar10_download_and_extract.py
python ../../r1/resnet/cifar10_download_and_extract.py
```
After you download the data, you can run the program by:
```bash
python keras_cifar_main.py
python resnet_cifar_main.py
```
If you did not use the default directory to download the data, specify the
If you did not use the default directory to download the data, specify the
location with the `--data_dir` flag, like:
```bash
python keras_cifar_main.py --data_dir=/path/to/cifar
python resnet_cifar_main.py --data_dir=/path/to/cifar
```
## ImageNet
Download the ImageNet dataset and convert it to TFRecord format.
Download the ImageNet dataset and convert it to TFRecord format.
The following [script](https://github.com/tensorflow/tpu/blob/master/tools/datasets/imagenet_to_gcs.py)
and [README](https://github.com/tensorflow/tpu/tree/master/tools/datasets#imagenet_to_gcspy)
provide a few options.
......@@ -44,57 +44,57 @@ provide a few options.
Once your dataset is ready, you can begin training the model as follows:
```bash
python keras_imagenet_main.py
python resnet_imagenet_main.py
```
Again, if you did not download the data to the default directory, specify the
location with the `--data_dir` flag:
```bash
python keras_imagenet_main.py --data_dir=/path/to/imagenet
python resnet_imagenet_main.py --data_dir=/path/to/imagenet
```
There are more flag options you can specify. Here are some examples:
- `--use_synthetic_data`: when set to true, synthetic data, rather than real
- `--use_synthetic_data`: when set to true, synthetic data, rather than real
data, are used;
- `--batch_size`: the batch size used for the model;
- `--model_dir`: the directory to save the model checkpoint;
- `--train_epochs`: number of epochs to run for training the model;
- `--train_steps`: number of steps to run for training the model. We now only
support a number that is smaller than the number of batches in an epoch.
- `--skip_eval`: when set to true, evaluation as well as validation during
- `--skip_eval`: when set to true, evaluation as well as validation during
training is skipped.
For example, this is a typical command line to run with ImageNet data with
For example, this is a typical command line to run with ImageNet data with
batch size 128 per GPU:
```bash
python -m keras_imagenet_main \
--model_dir=/tmp/model_dir/something \
--num_gpus=2 \
--batch_size=128 \
--train_epochs=90 \
--train_steps=10 \
--use_synthetic_data=false
python -m resnet_imagenet_main \
--model_dir=/tmp/model_dir/something \
--num_gpus=2 \
--batch_size=128 \
--train_epochs=90 \
--train_steps=10 \
--use_synthetic_data=false
```
See [`keras_common.py`](keras_common.py) for full list of options.
See [`common.py`](common.py) for full list of options.
## Using multiple GPUs
You can train these models on multiple GPUs using `tf.distribute.Strategy` API.
You can read more about them in this
You can train these models on multiple GPUs using `tf.distribute.Strategy` API.
You can read more about them in this
[guide](https://www.tensorflow.org/guide/distribute_strategy).
In this example, we have made it easier to use with just a command line flag
`--num_gpus`. By default this flag is 1 if TensorFlow is compiled with CUDA,
In this example, we have made it easier to use with just a command line flag
`--num_gpus`. By default this flag is 1 if TensorFlow is compiled with CUDA,
and 0 otherwise.
- --num_gpus=0: Uses tf.distribute.OneDeviceStrategy with CPU as the device.
- --num_gpus=1: Uses tf.distribute.OneDeviceStrategy with GPU as the device.
- --num_gpus=2+: Uses tf.distribute.MirroredStrategy to run synchronous
- --num_gpus=2+: Uses tf.distribute.MirroredStrategy to run synchronous
distributed training across the GPUs.
If you wish to run without `tf.distribute.Strategy`, you can do so by setting
If you wish to run without `tf.distribute.Strategy`, you can do so by setting
`--distribution_strategy=off`.
......@@ -22,7 +22,7 @@ import os
from absl import logging
import tensorflow as tf
from official.resnet.keras import imagenet_preprocessing
from official.vision.image_classification import imagenet_preprocessing
HEIGHT = 32
WIDTH = 32
......
......@@ -20,17 +20,13 @@ from __future__ import print_function
import multiprocessing
import os
import numpy as np
# pylint: disable=g-bad-import-order
from absl import flags
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
from official.utils.flags import core as flags_core
from official.utils.misc import keras_utils
# pylint: disable=ungrouped-imports
from tensorflow.python.keras.optimizer_v2 import (gradient_descent as
gradient_descent_v2)
FLAGS = flags.FLAGS
BASE_LEARNING_RATE = 0.1 # This matches Jing's version.
......@@ -262,6 +258,7 @@ def define_keras_flags(dynamic_loss_scale=True):
force_v2_in_keras_compile=True)
flags_core.define_image()
flags_core.define_benchmark()
flags_core.define_distribution()
flags.adopt_module_key_flags(flags_core)
flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
......
......@@ -12,21 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the keras_common module."""
"""Tests for the common module."""
from __future__ import absolute_import
from __future__ import print_function
from mock import Mock
import numpy as np
import tensorflow as tf # pylint: disable=g-bad-import-order
from tensorflow.python.platform import googletest
import tensorflow as tf
from official.resnet.keras import keras_common
from tensorflow.python.platform import googletest
from official.utils.misc import keras_utils
from official.vision.image_classification import common
class KerasCommonTests(tf.test.TestCase):
"""Tests for keras_common."""
"""Tests for common."""
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
......@@ -42,7 +42,7 @@ class KerasCommonTests(tf.test.TestCase):
keras_utils.BatchTimestamp(1, 2),
keras_utils.BatchTimestamp(2, 3)]
th.train_finish_time = 12345
stats = keras_common.build_stats(history, eval_output, [th])
stats = common.build_stats(history, eval_output, [th])
self.assertEqual(1.145, stats['loss'])
self.assertEqual(.99988, stats['training_accuracy_top_1'])
......@@ -57,7 +57,7 @@ class KerasCommonTests(tf.test.TestCase):
history = self._build_history(1.145, cat_accuracy_sparse=.99988)
eval_output = self._build_eval_output(.928, 1.9844)
stats = keras_common.build_stats(history, eval_output, None)
stats = common.build_stats(history, eval_output, None)
self.assertEqual(1.145, stats['loss'])
self.assertEqual(.99988, stats['training_accuracy_top_1'])
......
......@@ -22,13 +22,13 @@ from absl import app as absl_app
from absl import flags
import tensorflow as tf
from official.resnet.keras import cifar_preprocessing
from official.resnet.keras import keras_common
from official.resnet.keras import resnet_cifar_model
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.vision.image_classification import cifar_preprocessing
from official.vision.image_classification import common
from official.vision.image_classification import resnet_cifar_model
LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
......@@ -55,7 +55,7 @@ def learning_rate_schedule(current_epoch,
Adjusted learning rate.
"""
del current_batch, batches_per_epoch # not used
initial_learning_rate = keras_common.BASE_LEARNING_RATE * batch_size / 128
initial_learning_rate = common.BASE_LEARNING_RATE * batch_size / 128
learning_rate = initial_learning_rate
for mult, start_epoch in LR_SCHEDULE:
if current_epoch >= start_epoch:
......@@ -83,8 +83,8 @@ def run(flags_obj):
# Execute flag override logic for better model performance
if flags_obj.tf_gpu_thread_mode:
keras_common.set_gpu_thread_mode_and_count(flags_obj)
keras_common.set_cudnn_batchnorm_mode()
common.set_gpu_thread_mode_and_count(flags_obj)
common.set_cudnn_batchnorm_mode()
dtype = flags_core.get_tf_dtype(flags_obj)
if dtype == 'fp16':
......@@ -116,7 +116,7 @@ def run(flags_obj):
if flags_obj.use_synthetic_data:
distribution_utils.set_up_synthetic_data()
input_fn = keras_common.get_synth_input_fn(
input_fn = common.get_synth_input_fn(
height=cifar_preprocessing.HEIGHT,
width=cifar_preprocessing.WIDTH,
num_channels=cifar_preprocessing.NUM_CHANNELS,
......@@ -150,7 +150,7 @@ def run(flags_obj):
parse_record_fn=cifar_preprocessing.parse_record)
with strategy_scope:
optimizer = keras_common.get_optimizer()
optimizer = common.get_optimizer()
model = resnet_cifar_model.resnet56(classes=cifar_preprocessing.NUM_CLASSES)
# TODO(b/138957587): Remove when force_v2_in_keras_compile is on longer
......@@ -171,7 +171,7 @@ def run(flags_obj):
if flags_obj.report_accuracy_metrics else None),
run_eagerly=flags_obj.run_eagerly)
callbacks = keras_common.get_callbacks(
callbacks = common.get_callbacks(
learning_rate_schedule, cifar_preprocessing.NUM_IMAGES['train'])
train_steps = cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size
......@@ -216,12 +216,12 @@ def run(flags_obj):
if not strategy and flags_obj.explicit_gpu_placement:
no_dist_strat_device.__exit__()
stats = keras_common.build_stats(history, eval_output, callbacks)
stats = common.build_stats(history, eval_output, callbacks)
return stats
def define_cifar_flags():
keras_common.define_keras_flags(dynamic_loss_scale=False)
common.define_keras_flags(dynamic_loss_scale=False)
flags_core.set_defaults(data_dir='/tmp/cifar10_data/cifar-10-batches-bin',
model_dir='/tmp/cifar10_model',
......
......@@ -18,17 +18,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tempfile import mkdtemp
import tempfile
import tensorflow as tf
from official.resnet.keras import cifar_preprocessing
from official.resnet.keras import keras_cifar_main
from official.resnet.keras import keras_common
from official.utils.misc import keras_utils
from official.utils.testing import integration
# pylint: disable=ungrouped-imports
from tensorflow.python.eager import context
from tensorflow.python.platform import googletest
from official.utils.misc import keras_utils
from official.utils.testing import integration
from official.vision.image_classification import cifar_preprocessing
from official.vision.image_classification import resnet_cifar_main
class KerasCifarTest(googletest.TestCase):
......@@ -43,13 +42,13 @@ class KerasCifarTest(googletest.TestCase):
def get_temp_dir(self):
if not self._tempdir:
self._tempdir = mkdtemp(dir=googletest.GetTempDir())
self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
return self._tempdir
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(KerasCifarTest, cls).setUpClass()
keras_cifar_main.define_cifar_flags()
resnet_cifar_main.define_cifar_flags()
def setUp(self):
super(KerasCifarTest, self).setUp()
......@@ -72,7 +71,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_cifar_main.run,
main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -88,7 +87,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_cifar_main.run,
main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -112,7 +111,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_cifar_main.run,
main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -134,7 +133,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_cifar_main.run,
main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -157,7 +156,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_cifar_main.run,
main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......@@ -178,7 +177,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags
integration.run_synthetic(
main=keras_cifar_main.run,
main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags
)
......