Commit bf748370 authored by Nimit Nigania
Browse files

Merge remote-tracking branch 'upstream/master'

parents 7c732da7 0d2c2e01
...@@ -21,6 +21,7 @@ import os ...@@ -21,6 +21,7 @@ import os
import time import time
from absl import flags from absl import flags
import tensorflow as tf
from official.transformer.v2 import misc from official.transformer.v2 import misc
from official.transformer.v2 import transformer_main as transformer_main from official.transformer.v2 import transformer_main as transformer_main
...@@ -30,6 +31,7 @@ from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark ...@@ -30,6 +31,7 @@ from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
TRANSFORMER_EN2DE_DATA_DIR_NAME = 'wmt32k-en2de-official' TRANSFORMER_EN2DE_DATA_DIR_NAME = 'wmt32k-en2de-official'
EN2DE_2014_BLEU_DATA_DIR_NAME = 'newstest2014' EN2DE_2014_BLEU_DATA_DIR_NAME = 'newstest2014'
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
TMP_DIR = os.getenv('TMPDIR')
class TransformerBenchmark(PerfZeroBenchmark): class TransformerBenchmark(PerfZeroBenchmark):
...@@ -56,6 +58,11 @@ class TransformerBenchmark(PerfZeroBenchmark): ...@@ -56,6 +58,11 @@ class TransformerBenchmark(PerfZeroBenchmark):
EN2DE_2014_BLEU_DATA_DIR_NAME, EN2DE_2014_BLEU_DATA_DIR_NAME,
'newstest2014.de') 'newstest2014.de')
default_flags['train_steps'] = 200
default_flags['log_steps'] = 10
default_flags['data_dir'] = self.train_data_dir
default_flags['vocab_file'] = self.vocab_file
super(TransformerBenchmark, self).__init__( super(TransformerBenchmark, self).__init__(
output_dir=output_dir, output_dir=output_dir,
default_flags=default_flags, default_flags=default_flags,
...@@ -280,8 +287,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -280,8 +287,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu') FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size, self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps, log_steps=FLAGS.log_steps,
bleu_min=28, bleu_min=27.9,
bleu_max=29) bleu_max=29.2)
def benchmark_8_gpu_static_batch(self): def benchmark_8_gpu_static_batch(self):
"""Benchmark 8 gpu. """Benchmark 8 gpu.
...@@ -305,12 +312,19 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -305,12 +312,19 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size, self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps, log_steps=FLAGS.log_steps,
bleu_min=28, bleu_min=28,
bleu_max=29) bleu_max=29.2)
def benchmark_8_gpu_fp16(self): def benchmark_8_gpu_fp16(self):
"""Benchmark 8 gpu with dynamic batch and fp16. """Benchmark 8 gpu with dynamic batch and fp16.
Should converge to 28.4 BLEU (uncased). This has not be verified yet." Over 6 runs with eval every 20K steps the average highest value was 28.247
(bleu uncased). 28.424 was the highest and 28.09 the lowest. The values are
the highest value seen during a run and occurred at a median of iteration
11. While this could be interpreted as worse than FP32, if looking at the
first iteration at which 28 is passed FP16 performs equal and possibly
better. Although not part of the initial test runs, the highest value
recorded with the arguments below was 28.9 at iteration 12. Iterations are
not epochs, an iteration is a number of steps between evals.
""" """
self._setup() self._setup()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
...@@ -328,7 +342,7 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -328,7 +342,7 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size, self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps, log_steps=FLAGS.log_steps,
bleu_min=28, bleu_min=28,
bleu_max=29) bleu_max=29.2)
def benchmark_8_gpu_static_batch_fp16(self): def benchmark_8_gpu_static_batch_fp16(self):
"""Benchmark 8 gpu with static batch and fp16. """Benchmark 8 gpu with static batch and fp16.
...@@ -353,7 +367,7 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -353,7 +367,7 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size, self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps, log_steps=FLAGS.log_steps,
bleu_min=28, bleu_min=28,
bleu_max=29) bleu_max=29.2)
def benchmark_xla_8_gpu_static_batch_fp16(self): def benchmark_xla_8_gpu_static_batch_fp16(self):
"""Benchmark 8 gpu with static batch, XLA, and FP16. """Benchmark 8 gpu with static batch, XLA, and FP16.
...@@ -380,7 +394,7 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -380,7 +394,7 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size, self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps, log_steps=FLAGS.log_steps,
bleu_min=28, bleu_min=28,
bleu_max=29) bleu_max=29.2)
class TransformerKerasBenchmark(TransformerBenchmark): class TransformerKerasBenchmark(TransformerBenchmark):
...@@ -611,19 +625,9 @@ class TransformerKerasBenchmark(TransformerBenchmark): ...@@ -611,19 +625,9 @@ class TransformerKerasBenchmark(TransformerBenchmark):
class TransformerBaseKerasBenchmarkReal(TransformerKerasBenchmark): class TransformerBaseKerasBenchmarkReal(TransformerKerasBenchmark):
"""Transformer based version real data benchmark tests.""" """Transformer based version real data benchmark tests."""
def __init__(self, output_dir=None, root_data_dir=None, **kwargs): def __init__(self, output_dir=TMP_DIR, root_data_dir=None, **kwargs):
train_data_dir = os.path.join(root_data_dir,
TRANSFORMER_EN2DE_DATA_DIR_NAME)
vocab_file = os.path.join(root_data_dir,
TRANSFORMER_EN2DE_DATA_DIR_NAME,
'vocab.ende.32768')
def_flags = {} def_flags = {}
def_flags['param_set'] = 'base' def_flags['param_set'] = 'base'
def_flags['vocab_file'] = vocab_file
def_flags['data_dir'] = train_data_dir
def_flags['train_steps'] = 200
def_flags['log_steps'] = 10
super(TransformerBaseKerasBenchmarkReal, self).__init__( super(TransformerBaseKerasBenchmarkReal, self).__init__(
output_dir=output_dir, default_flags=def_flags, output_dir=output_dir, default_flags=def_flags,
...@@ -633,20 +637,14 @@ class TransformerBaseKerasBenchmarkReal(TransformerKerasBenchmark): ...@@ -633,20 +637,14 @@ class TransformerBaseKerasBenchmarkReal(TransformerKerasBenchmark):
class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
  """Real-data benchmark tests for the 'big' Transformer parameter set."""

  def __init__(self, output_dir=TMP_DIR, root_data_dir=None, **kwargs):
    """Configures the benchmark with param_set='big' and 3072 batch per GPU.

    Args:
      output_dir: Directory handed to the parent benchmark; defaults to the
        module-level TMP_DIR.
      root_data_dir: Root data directory, forwarded unchanged to the parent.
      **kwargs: Accepted for call compatibility; not forwarded by this
        constructor.
    """
    big_flags = {'param_set': 'big'}

    super(TransformerBigKerasBenchmarkReal, self).__init__(
        output_dir=output_dir,
        default_flags=big_flags,
        root_data_dir=root_data_dir,
        batch_per_gpu=3072)
if __name__ == '__main__':
tf.test.main()
...@@ -27,12 +27,16 @@ import tempfile ...@@ -27,12 +27,16 @@ import tempfile
from absl import app as absl_app # pylint: disable=unused-import from absl import app as absl_app # pylint: disable=unused-import
from absl import flags from absl import flags
from absl import logging
import tensorflow as tf import tensorflow as tf
from tensorflow.python.util import object_identity
# pylint: disable=g-bad-import-order # pylint: disable=g-bad-import-order
from official.transformer import compute_bleu from official.transformer import compute_bleu
from official.transformer.utils import tokenizer from official.transformer.utils import tokenizer
from official.transformer.v2 import data_pipeline from official.transformer.v2 import data_pipeline
from official.transformer.v2 import metrics
from official.transformer.v2 import misc from official.transformer.v2 import misc
from official.transformer.v2 import optimizer from official.transformer.v2 import optimizer
from official.transformer.v2 import transformer from official.transformer.v2 import transformer
...@@ -48,18 +52,40 @@ BLEU_DIR = "bleu" ...@@ -48,18 +52,40 @@ BLEU_DIR = "bleu"
_SINGLE_SAMPLE = 1 _SINGLE_SAMPLE = 1
def translate_and_compute_bleu(model, subtokenizer, bleu_source, bleu_ref): def translate_and_compute_bleu(model,
"""Translate file and report the cased and uncased bleu scores.""" params,
subtokenizer,
bleu_source,
bleu_ref,
distribution_strategy=None):
"""Translate file and report the cased and uncased bleu scores.
Args:
model: A Keras model, used to generate the translations.
params: A dictionary, containing the translation related parameters.
subtokenizer: A subtokenizer object, used for encoding and decoding source
and translated lines.
bleu_source: A file containing source sentences for translation.
bleu_ref: A file containing the reference for the translated sentences.
distribution_strategy: A platform distribution strategy, used for TPU based
translation.
Returns:
uncased_score: A float, the case insensitive BLEU score.
cased_score: A float, the case sensitive BLEU score.
"""
# Create temporary file to store translation. # Create temporary file to store translation.
tmp = tempfile.NamedTemporaryFile(delete=False) tmp = tempfile.NamedTemporaryFile(delete=False)
tmp_filename = tmp.name tmp_filename = tmp.name
translate.translate_file( translate.translate_file(
model, model,
params,
subtokenizer, subtokenizer,
bleu_source, bleu_source,
output_file=tmp_filename, output_file=tmp_filename,
print_all_translations=False) print_all_translations=False,
distribution_strategy=distribution_strategy)
# Compute uncased and cased bleu scores. # Compute uncased and cased bleu scores.
uncased_score = compute_bleu.bleu_wrapper(bleu_ref, tmp_filename, False) uncased_score = compute_bleu.bleu_wrapper(bleu_ref, tmp_filename, False)
...@@ -68,15 +94,34 @@ def translate_and_compute_bleu(model, subtokenizer, bleu_source, bleu_ref): ...@@ -68,15 +94,34 @@ def translate_and_compute_bleu(model, subtokenizer, bleu_source, bleu_ref):
return uncased_score, cased_score return uncased_score, cased_score
def evaluate_and_log_bleu(model,
                          params,
                          bleu_source,
                          bleu_ref,
                          vocab_file,
                          distribution_strategy=None):
  """Translates bleu_source, scores it against bleu_ref, and logs the scores.

  Args:
    model: A Keras model, used to generate the translations.
    params: A dictionary, containing the translation related parameters.
    bleu_source: A file containing source sentences for translation.
    bleu_ref: A file containing the reference for the translated sentences.
    vocab_file: A file containing the vocabulary for translation; used to
      build the Subtokenizer.
    distribution_strategy: A platform distribution strategy, passed through
      to translate_and_compute_bleu.

  Returns:
    uncased_score: A float, the case insensitive BLEU score.
    cased_score: A float, the case sensitive BLEU score.
  """
  vocab_subtokenizer = tokenizer.Subtokenizer(vocab_file)

  bleu_scores = translate_and_compute_bleu(
      model,
      params,
      vocab_subtokenizer,
      bleu_source,
      bleu_ref,
      distribution_strategy)
  uncased_score, cased_score = bleu_scores

  logging.info("Bleu score (uncased): %s", uncased_score)
  logging.info("Bleu score (cased): %s", cased_score)
  return uncased_score, cased_score
...@@ -88,30 +133,27 @@ class TransformerTask(object): ...@@ -88,30 +133,27 @@ class TransformerTask(object):
Args: Args:
flags_obj: Object containing parsed flag values, i.e., FLAGS. flags_obj: Object containing parsed flag values, i.e., FLAGS.
Raises:
ValueError: if not using static batch for input data on TPU.
""" """
self.flags_obj = flags_obj self.flags_obj = flags_obj
self.predict_model = None self.predict_model = None
# Add flag-defined parameters to params object # Add flag-defined parameters to params object
num_gpus = flags_core.get_num_gpus(flags_obj) num_gpus = flags_core.get_num_gpus(flags_obj)
self.distribution_strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_core.get_num_gpus(flags_obj))
print("Running transformer with num_gpus =", num_gpus)
if self.distribution_strategy:
print("For training, using distribution strategy: ",
self.distribution_strategy)
else:
print("Not using any distribution strategy.")
self.params = params = misc.get_model_params(flags_obj.param_set, num_gpus) self.params = params = misc.get_model_params(flags_obj.param_set, num_gpus)
params["num_gpus"] = num_gpus params["num_gpus"] = num_gpus
params["use_ctl"] = flags_obj.use_ctl
params["is_tpu_pod"] = flags_obj.is_tpu_pod
params["data_dir"] = flags_obj.data_dir params["data_dir"] = flags_obj.data_dir
params["model_dir"] = flags_obj.model_dir params["model_dir"] = flags_obj.model_dir
params["static_batch"] = flags_obj.static_batch params["static_batch"] = flags_obj.static_batch
params["max_length"] = flags_obj.max_length params["max_length"] = flags_obj.max_length
params["decode_batch_size"] = flags_obj.decode_batch_size
params["decode_max_length"] = flags_obj.decode_max_length
params["padded_decode"] = flags_obj.padded_decode
params["num_parallel_calls"] = ( params["num_parallel_calls"] = (
flags_obj.num_parallel_calls or tf.data.experimental.AUTOTUNE) flags_obj.num_parallel_calls or tf.data.experimental.AUTOTUNE)
...@@ -130,33 +172,114 @@ class TransformerTask(object): ...@@ -130,33 +172,114 @@ class TransformerTask(object):
"infer_float32_vars") "infer_float32_vars")
tf.keras.mixed_precision.experimental.set_policy(policy) tf.keras.mixed_precision.experimental.set_policy(policy)
self.distribution_strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=num_gpus,
tpu_address=flags_obj.tpu or "")
if self.use_tpu:
params["num_replicas"] = self.distribution_strategy.num_replicas_in_sync
if not params["static_batch"]:
raise ValueError("TPU requires static batch for input data.")
else:
print("Running transformer with num_gpus =", num_gpus)
if self.distribution_strategy:
print("For training, using distribution strategy: ",
self.distribution_strategy)
else:
print("Not using any distribution strategy.")
@property
def use_tpu(self):
  """Whether self.distribution_strategy is a TPUStrategy.

  Returns:
    False when no distribution strategy is set; otherwise the result of the
    isinstance check against tf.distribute.experimental.TPUStrategy.
  """
  strategy = self.distribution_strategy
  if not strategy:
    return False
  return isinstance(strategy, tf.distribute.experimental.TPUStrategy)
def train(self): def train(self):
"""Trains the model.""" """Trains the model."""
params, flags_obj, is_train = self.params, self.flags_obj, True params = self.params
flags_obj = self.flags_obj
# Sets config options. # Sets config options.
keras_utils.set_session_config( keras_utils.set_session_config(
enable_xla=flags_obj.enable_xla) enable_xla=flags_obj.enable_xla)
_ensure_dir(flags_obj.model_dir) _ensure_dir(flags_obj.model_dir)
if self.distribution_strategy: with distribution_utils.get_strategy_scope(self.distribution_strategy):
with self.distribution_strategy.scope(): model = transformer.create_model(params, is_train=True)
model = transformer.create_model(params, is_train)
opt = self._create_optimizer()
model.compile(opt)
else:
model = transformer.create_model(params, is_train)
opt = self._create_optimizer() opt = self._create_optimizer()
model.compile(opt) if params["use_ctl"]:
train_loss_metric = tf.keras.metrics.Mean(
"training_loss", dtype=tf.float32)
else:
model.compile(opt)
model.summary() model.summary()
train_ds = data_pipeline.train_input_fn(params) if self.use_tpu:
map_data_fn = data_pipeline.map_data_for_transformer_fn # Different from experimental_distribute_dataset,
train_ds = train_ds.map(map_data_fn, # experimental_distribute_datasets_from_function requires
num_parallel_calls=params["num_parallel_calls"]) # per-replica/local batch size.
params["batch_size"] /= self.distribution_strategy.num_replicas_in_sync
train_ds = (
self.distribution_strategy
.experimental_distribute_datasets_from_function(
lambda ctx: data_pipeline.train_input_fn(params)))
else:
train_ds = data_pipeline.train_input_fn(params)
map_data_fn = data_pipeline.map_data_for_transformer_fn
train_ds = train_ds.map(
map_data_fn, num_parallel_calls=params["num_parallel_calls"])
if params["use_ctl"]:
train_ds_iterator = iter(train_ds)
callbacks = self._create_callbacks(flags_obj.model_dir, 0, params) callbacks = self._create_callbacks(flags_obj.model_dir, 0, params)
# TODO(b/139418525): Refactor the custom training loop logic.
@tf.function
def train_steps(iterator, steps):
"""Training steps function for TPU runs.
Args:
iterator: The input iterator of the training dataset.
steps: An integer, the number of training steps.
Returns:
A float, the loss value.
"""
def _step_fn(inputs):
"""Per-replica step function."""
inputs, targets = inputs
with tf.GradientTape() as tape:
logits = model([inputs, targets], training=True)
loss = metrics.transformer_loss(logits, targets,
params["label_smoothing"],
params["vocab_size"])
# Scales the loss, which results in using the average loss across all
# of the replicas for backprop.
scaled_loss = loss / self.distribution_strategy.num_replicas_in_sync
# De-dupes variables due to keras tracking issues.
tvars = list(
object_identity.ObjectIdentitySet(model.trainable_variables))
grads = tape.gradient(scaled_loss, tvars)
opt.apply_gradients(zip(grads, tvars))
# For reporting, the metric takes the mean of losses.
train_loss_metric.update_state(loss)
for _ in tf.range(steps):
train_loss_metric.reset_states()
self.distribution_strategy.experimental_run_v2(
_step_fn, args=(next(iterator),))
if self.use_tpu:
checkpoint = tf.train.Checkpoint(model=model, optimizer=opt)
latest_checkpoint = tf.train.latest_checkpoint(flags_obj.model_dir)
if latest_checkpoint:
checkpoint.restore(latest_checkpoint)
logging.info("Loaded checkpoint %s", latest_checkpoint)
if flags_obj.train_steps < flags_obj.steps_between_evals: if flags_obj.train_steps < flags_obj.steps_between_evals:
flags_obj.steps_between_evals = flags_obj.train_steps flags_obj.steps_between_evals = flags_obj.train_steps
iterations = flags_obj.train_steps // flags_obj.steps_between_evals iterations = flags_obj.train_steps // flags_obj.steps_between_evals
...@@ -165,28 +288,54 @@ class TransformerTask(object): ...@@ -165,28 +288,54 @@ class TransformerTask(object):
cased_score_history, uncased_score_history = [], [] cased_score_history, uncased_score_history = [], []
for i in range(1, iterations + 1): for i in range(1, iterations + 1):
print("Start train iteration:{}/{}".format(i, iterations)) print("Start train iteration:{}/{}".format(i, iterations))
history = model.fit( history = None
train_ds, if params["use_ctl"]:
initial_epoch=i-1, if not self.use_tpu:
epochs=i, raise NotImplementedError(
steps_per_epoch=flags_obj.steps_between_evals, "Custom training loop on GPUs is not implemented.")
callbacks=callbacks, train_steps_per_eval = tf.convert_to_tensor(
# If TimeHistory is enabled, progress bar would be messy. Increase the flags_obj.steps_between_evals, dtype=tf.int32)
# verbose level to get rid of it.
verbose=(2 if flags_obj.enable_time_history else 1)) # Runs training steps.
train_steps(train_ds_iterator, train_steps_per_eval)
train_loss = train_loss_metric.result().numpy().astype(float)
logging.info("Train Step: %d/%d / loss = %s",
i * flags_obj.steps_between_evals, flags_obj.train_steps,
train_loss)
checkpoint_name = checkpoint.save(
os.path.join(
flags_obj.model_dir,
"ctl_step_{}.ckpt".format(i * flags_obj.steps_between_evals)))
logging.info("Saved checkpoint to %s", checkpoint_name)
else:
if self.use_tpu:
raise NotImplementedError(
"Keras model.fit on TPUs is not implemented.")
history = model.fit(
train_ds,
initial_epoch=i - 1,
epochs=i,
steps_per_epoch=flags_obj.steps_between_evals,
callbacks=callbacks,
# If TimeHistory is enabled, progress bar would be messy. Increase
# the verbose level to get rid of it.
verbose=(2 if flags_obj.enable_time_history else 1))
logging.info("Train history: {}".format(history.history))
print("End train iteration:{}/{} global step:{}".format( print("End train iteration:{}/{} global step:{}".format(
i, i,
iterations, iterations,
i*flags_obj.steps_between_evals)) i*flags_obj.steps_between_evals))
tf.compat.v1.logging.info("Train history: {}".format(history.history))
stats = misc.build_stats(history, callbacks)
if (flags_obj.bleu_source and flags_obj.bleu_ref): if (flags_obj.bleu_source and flags_obj.bleu_ref):
uncased_score, cased_score = self.eval() uncased_score, cased_score = self.eval()
cased_score_history.append([i, cased_score]) cased_score_history.append([i, cased_score])
uncased_score_history.append([i, uncased_score]) uncased_score_history.append([i, uncased_score])
stats = misc.build_stats(history, callbacks) stats = ({
"loss": train_loss
} if history is None else misc.build_stats(history, callbacks))
if uncased_score and cased_score: if uncased_score and cased_score:
stats["bleu_uncased"] = uncased_score stats["bleu_uncased"] = uncased_score
stats["bleu_cased"] = cased_score stats["bleu_cased"] = cased_score
...@@ -202,17 +351,18 @@ class TransformerTask(object): ...@@ -202,17 +351,18 @@ class TransformerTask(object):
self.predict_model, self.predict_model,
tf.train.latest_checkpoint(self.flags_obj.model_dir)) tf.train.latest_checkpoint(self.flags_obj.model_dir))
self.predict_model.summary() self.predict_model.summary()
return evaluate_and_log_bleu(self.predict_model, return evaluate_and_log_bleu(
self.flags_obj.bleu_source, self.predict_model, self.params, self.flags_obj.bleu_source,
self.flags_obj.bleu_ref, self.flags_obj.bleu_ref, self.flags_obj.vocab_file,
self.flags_obj.vocab_file) self.distribution_strategy if self.use_tpu else None)
def predict(self): def predict(self):
"""Predicts result from the model.""" """Predicts result from the model."""
params, flags_obj, is_train = self.params, self.flags_obj, False params = self.params
flags_obj = self.flags_obj
with tf.name_scope("model"): with tf.name_scope("model"):
model = transformer.create_model(params, is_train) model = transformer.create_model(params, is_train=False)
self._load_weights_if_possible( self._load_weights_if_possible(
model, tf.train.latest_checkpoint(self.flags_obj.model_dir)) model, tf.train.latest_checkpoint(self.flags_obj.model_dir))
model.summary() model.summary()
...@@ -242,16 +392,28 @@ class TransformerTask(object): ...@@ -242,16 +392,28 @@ class TransformerTask(object):
def _load_weights_if_possible(self, model, init_weight_path=None): def _load_weights_if_possible(self, model, init_weight_path=None):
"""Loads model weights when it is provided.""" """Loads model weights when it is provided."""
if init_weight_path: if init_weight_path:
tf.compat.v1.logging.info("Load weights: {}".format(init_weight_path)) logging.info("Load weights: {}".format(init_weight_path))
model.load_weights(init_weight_path) # TODO(b/139414977): Having the same variable restoring method for both
# TPU and GPU.
if self.use_tpu:
checkpoint = tf.train.Checkpoint(
model=model, optimizer=self._create_optimizer())
checkpoint.restore(init_weight_path)
else:
model.load_weights(init_weight_path)
else: else:
print("Weights not loaded from path:{}".format(init_weight_path)) print("Weights not loaded from path:{}".format(init_weight_path))
def _create_optimizer(self): def _create_optimizer(self):
"""Creates optimizer.""" """Creates optimizer."""
params = self.params params = self.params
# TODO(b/139414679): Explore the difference between using
# LearningRateSchedule and callback for GPU runs, and try to merge them.
lr_schedule = optimizer.LearningRateSchedule(
params["learning_rate"], params["hidden_size"],
params["learning_rate_warmup_steps"])
opt = tf.keras.optimizers.Adam( opt = tf.keras.optimizers.Adam(
params["learning_rate"], lr_schedule if self.use_tpu else params["learning_rate"],
params["optimizer_adam_beta1"], params["optimizer_adam_beta1"],
params["optimizer_adam_beta2"], params["optimizer_adam_beta2"],
epsilon=params["optimizer_adam_epsilon"]) epsilon=params["optimizer_adam_epsilon"])
...@@ -264,25 +426,34 @@ class TransformerTask(object): ...@@ -264,25 +426,34 @@ class TransformerTask(object):
def _ensure_dir(log_dir):
  """Creates log_dir through tf.io.gfile unless it already exists."""
  if tf.io.gfile.exists(log_dir):
    return
  tf.io.gfile.makedirs(log_dir)
def main(_):
  """Builds a TransformerTask from the parsed flags and runs the chosen mode.

  Args:
    _: Unused positional argument supplied by absl_app.run.

  Raises:
    ValueError: if flags_obj.mode is not "train", "predict" or "eval".
  """
  flags_obj = flags.FLAGS
  with logger.benchmark_context(flags_obj):
    task = TransformerTask(flags_obj)

    def _run_task(task):
      """Calls the task method selected by flags_obj.mode."""
      if flags_obj.mode == "train":
        task.train()
      elif flags_obj.mode == "predict":
        task.predict()
      elif flags_obj.mode == "eval":
        task.eval()
      else:
        raise ValueError("Invalid mode {}".format(flags_obj.mode))

    # The previous condition was `not ... != "tpu"`, a double negation that
    # sent TPU runs down the plain branch and everything else through the
    # device scope. Only TPU runs need the explicit device placement.
    if flags_obj.distribution_strategy != "tpu":
      _run_task(task)
    else:
      # With the 2-VM TPU config the task is pinned to "/job:worker";
      # otherwise an empty device string leaves placement unconstrained.
      primary_cpu_task = "/job:worker" if flags_obj.use_tpu_2vm_config else ""
      with tf.device(primary_cpu_task):
        _run_task(task)
if __name__ == "__main__": if __name__ == "__main__":
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) logging.set_verbosity(logging.INFO)
misc.define_transformer_flags() misc.define_transformer_flags()
absl_app.run(main) absl_app.run(main)
...@@ -30,7 +30,7 @@ from official.transformer.v2 import misc ...@@ -30,7 +30,7 @@ from official.transformer.v2 import misc
from official.transformer.v2 import transformer_main as tm from official.transformer.v2 import transformer_main as tm
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
from tensorflow.python.eager import context # pylint: disable=ungrouped-imports from tensorflow.python.eager import context # pylint: disable=ungrouped-imports
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
FIXED_TIMESTAMP = 'my_time_stamp' FIXED_TIMESTAMP = 'my_time_stamp'
...@@ -80,11 +80,19 @@ class TransformerTaskTest(tf.test.TestCase): ...@@ -80,11 +80,19 @@ class TransformerTaskTest(tf.test.TestCase):
self.assertTrue(os.path.exists(filepath)) self.assertTrue(os.path.exists(filepath))
def test_train_no_dist_strat(self):
  """Runs train() with the current FLAGS; skipped when 2+ GPUs are present."""
  has_multiple_gpus = context.num_gpus() >= 2
  if has_multiple_gpus:
    self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
  task = tm.TransformerTask(FLAGS)
  task.train()
def test_train_static_batch(self): def test_train_static_batch(self):
if context.num_gpus() >= 2:
self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
FLAGS.distribution_strategy = 'one_device' FLAGS.distribution_strategy = 'one_device'
if tf.test.is_built_with_cuda():
FLAGS.num_gpus = 1
else:
FLAGS.num_gpus = 0
FLAGS.static_batch = True FLAGS.static_batch = True
t = tm.TransformerTask(FLAGS) t = tm.TransformerTask(FLAGS)
t.train() t.train()
...@@ -97,6 +105,7 @@ class TransformerTaskTest(tf.test.TestCase): ...@@ -97,6 +105,7 @@ class TransformerTaskTest(tf.test.TestCase):
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_train_fp16(self): def test_train_fp16(self):
FLAGS.distribution_strategy = 'one_device'
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
t = tm.TransformerTask(FLAGS) t = tm.TransformerTask(FLAGS)
t.train() t.train()
...@@ -105,8 +114,8 @@ class TransformerTaskTest(tf.test.TestCase): ...@@ -105,8 +114,8 @@ class TransformerTaskTest(tf.test.TestCase):
def test_train_2_gpu(self): def test_train_2_gpu(self):
if context.num_gpus() < 2: if context.num_gpus() < 2:
self.skipTest( self.skipTest(
'{} GPUs are not available for this test. {} GPUs are available'. '{} GPUs are not available for this test. {} GPUs are available'
format(2, context.num_gpus())) .format(2, context.num_gpus()))
FLAGS.distribution_strategy = 'mirrored' FLAGS.distribution_strategy = 'mirrored'
FLAGS.num_gpus = 2 FLAGS.num_gpus = 2
FLAGS.param_set = 'base' FLAGS.param_set = 'base'
...@@ -117,8 +126,8 @@ class TransformerTaskTest(tf.test.TestCase): ...@@ -117,8 +126,8 @@ class TransformerTaskTest(tf.test.TestCase):
def test_train_2_gpu_fp16(self): def test_train_2_gpu_fp16(self):
if context.num_gpus() < 2: if context.num_gpus() < 2:
self.skipTest( self.skipTest(
'{} GPUs are not available for this test. {} GPUs are available'. '{} GPUs are not available for this test. {} GPUs are available'
format(2, context.num_gpus())) .format(2, context.num_gpus()))
FLAGS.distribution_strategy = 'mirrored' FLAGS.distribution_strategy = 'mirrored'
FLAGS.num_gpus = 2 FLAGS.num_gpus = 2
FLAGS.param_set = 'base' FLAGS.param_set = 'base'
...@@ -153,16 +162,22 @@ class TransformerTaskTest(tf.test.TestCase): ...@@ -153,16 +162,22 @@ class TransformerTaskTest(tf.test.TestCase):
FLAGS(update_flags) FLAGS(update_flags)
def test_predict(self):
  """Runs predict() after preparing files and flags; skipped on 2+ GPUs."""
  has_multiple_gpus = context.num_gpus() >= 2
  if has_multiple_gpus:
    self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
  self._prepare_files_and_flags()
  task = tm.TransformerTask(FLAGS)
  task.predict()
def test_predict_fp16(self):
  """Runs predict() with --dtype=fp16 prepared flags; skipped on 2+ GPUs."""
  has_multiple_gpus = context.num_gpus() >= 2
  if has_multiple_gpus:
    self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
  self._prepare_files_and_flags('--dtype=fp16')
  task = tm.TransformerTask(FLAGS)
  task.predict()
def test_eval(self): def test_eval(self):
if context.num_gpus() >= 2:
self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
self._prepare_files_and_flags() self._prepare_files_and_flags()
t = tm.TransformerTask(FLAGS) t = tm.TransformerTask(FLAGS)
t.eval() t.eval()
......
...@@ -18,11 +18,12 @@ from __future__ import absolute_import ...@@ -18,11 +18,12 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.python.distribute import values
from official.transformer.utils import tokenizer from official.transformer.utils import tokenizer
_DECODE_BATCH_SIZE = 32
_EXTRA_DECODE_LENGTH = 100 _EXTRA_DECODE_LENGTH = 100
_BEAM_SIZE = 4 _BEAM_SIZE = 4
_ALPHA = 0.6 _ALPHA = 0.6
...@@ -68,23 +69,31 @@ def _trim_and_decode(ids, subtokenizer): ...@@ -68,23 +69,31 @@ def _trim_and_decode(ids, subtokenizer):
return subtokenizer.decode(ids) return subtokenizer.decode(ids)
def translate_file( def translate_file(model,
model, subtokenizer, input_file, output_file=None, params,
print_all_translations=True): subtokenizer,
input_file,
output_file=None,
print_all_translations=True,
distribution_strategy=None):
"""Translate lines in file, and save to output file if specified. """Translate lines in file, and save to output file if specified.
Args: Args:
model: Keras model used to generate the translations. model: A Keras model, used to generate the translations.
subtokenizer: Subtokenizer object for encoding and decoding source and params: A dictionary, containing the translation related parameters.
translated lines. subtokenizer: A subtokenizer object, used for encoding and decoding source
input_file: file containing lines to translate and translated lines.
output_file: file that stores the generated translations. input_file: A file containing lines to translate.
print_all_translations: If true, all translations are printed to stdout. output_file: A file that stores the generated translations.
print_all_translations: A bool. If true, all translations are printed to
stdout.
distribution_strategy: A distribution strategy, used to perform inference
directly with tf.function instead of Keras model.predict().
Raises: Raises:
ValueError: if output file is invalid. ValueError: if output file is invalid.
""" """
batch_size = _DECODE_BATCH_SIZE batch_size = params["decode_batch_size"]
# Read and sort inputs by length. Keep dictionary (original index-->new index # Read and sort inputs by length. Keep dictionary (original index-->new index
# in sorted list) to write translations in the original order. # in sorted list) to write translations in the original order.
...@@ -101,24 +110,59 @@ def translate_file( ...@@ -101,24 +110,59 @@ def translate_file(
if j + i * batch_size < total_samples if j + i * batch_size < total_samples
] ]
lines = [_encode_and_add_eos(l, subtokenizer) for l in lines] lines = [_encode_and_add_eos(l, subtokenizer) for l in lines]
if distribution_strategy:
for j in range(batch_size - len(lines)):
lines.append([tokenizer.EOS_ID])
batch = tf.keras.preprocessing.sequence.pad_sequences( batch = tf.keras.preprocessing.sequence.pad_sequences(
lines, dtype="int64", padding="post") lines,
maxlen=params["decode_max_length"],
dtype="int32",
padding="post")
tf.compat.v1.logging.info("Decoding batch %d out of %d.", i, tf.compat.v1.logging.info("Decoding batch %d out of %d.", i,
num_decode_batches) num_decode_batches)
yield batch yield batch
@tf.function
def predict_step(inputs):
"""Decoding step function for TPU runs."""
def _step_fn(inputs):
"""Per replica step function."""
val_outputs, _ = model([inputs], training=False)
return val_outputs
return distribution_strategy.experimental_run_v2(_step_fn, args=(inputs,))
translations = [] translations = []
if distribution_strategy:
num_replicas = distribution_strategy.num_replicas_in_sync
local_batch_size = params["decode_batch_size"] // num_replicas
for i, text in enumerate(input_generator()): for i, text in enumerate(input_generator()):
val_outputs, _ = model.predict(text) if distribution_strategy:
text = np.reshape(text, [num_replicas, local_batch_size, -1])
text = [
tf.convert_to_tensor(per_replica_text) for per_replica_text in text
]
# pylint: disable=protected-access
text = values.PerReplica(distribution_strategy.extended._device_map, text)
# pylint: enable=protected-access
val_outputs = distribution_strategy.experimental_local_results(
predict_step(text))
val_outputs = np.reshape(
[val_output.numpy() for val_output in val_outputs],
[params["decode_batch_size"], -1])
else:
val_outputs, _ = model.predict(text)
length = len(val_outputs) length = len(val_outputs)
for j in range(length): for j in range(length):
translation = _trim_and_decode(val_outputs[j], subtokenizer) if j + i * batch_size < total_samples:
translations.append(translation) translation = _trim_and_decode(val_outputs[j], subtokenizer)
if print_all_translations: translations.append(translation)
tf.compat.v1.logging.info( if print_all_translations:
"Translating:\n\tInput: %s\n\tOutput: %s" % tf.compat.v1.logging.info(
(sorted_inputs[j + i * batch_size], translation)) "Translating:\n\tInput: %s\n\tOutput: %s" %
(sorted_inputs[j + i * batch_size], translation))
# Write translations in the order they appeared in the original file. # Write translations in the order they appeared in the original file.
if output_file is not None: if output_file is not None:
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flags related to distributed execution."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
import tensorflow as tf
from official.utils.flags._conventions import help_wrap
def define_distribution(worker_hosts=True, task_index=True):
  """Register the flags used for multi-worker distributed execution.

  Args:
    worker_hosts: Whether to define the `worker_hosts` string flag (a
      comma-separated list of workers).
    task_index: Whether to define the `task_index` integer flag.

  Returns:
    A list of flags for core.py to mark as key flags. This function adds no
    names to it, so the list is always empty.
  """
  if worker_hosts:
    worker_hosts_help = help_wrap(
        'Comma-separated list of worker ip:port pairs for running '
        'multi-worker models with DistributionStrategy. The user would '
        'start the program on each host with identical value for this '
        'flag.')
    flags.DEFINE_string(
        name='worker_hosts', default=None, help=worker_hosts_help)

  if task_index:
    task_index_help = help_wrap(
        'If multi-worker training, the task_index of this worker.')
    flags.DEFINE_integer(name='task_index', default=-1, help=task_index_help)

  return []
...@@ -53,8 +53,8 @@ def get_loss_scale(flags_obj, default_for_fp16): ...@@ -53,8 +53,8 @@ def get_loss_scale(flags_obj, default_for_fp16):
return default_for_fp16 return default_for_fp16
def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True, def define_performance(num_parallel_calls=False, inter_op=False, intra_op=False,
synthetic_data=True, max_train_steps=True, dtype=True, synthetic_data=True, max_train_steps=False, dtype=True,
all_reduce_alg=True, num_packs=True, all_reduce_alg=True, num_packs=True,
tf_gpu_thread_mode=False, tf_gpu_thread_mode=False,
datasets_num_private_threads=False, datasets_num_private_threads=False,
......
...@@ -32,6 +32,7 @@ from official.utils.flags import _base ...@@ -32,6 +32,7 @@ from official.utils.flags import _base
from official.utils.flags import _benchmark from official.utils.flags import _benchmark
from official.utils.flags import _conventions from official.utils.flags import _conventions
from official.utils.flags import _device from official.utils.flags import _device
from official.utils.flags import _distribution
from official.utils.flags import _misc from official.utils.flags import _misc
from official.utils.flags import _performance from official.utils.flags import _performance
...@@ -77,6 +78,8 @@ define_benchmark = register_key_flags_in_core(_benchmark.define_benchmark) ...@@ -77,6 +78,8 @@ define_benchmark = register_key_flags_in_core(_benchmark.define_benchmark)
define_device = register_key_flags_in_core(_device.define_device) define_device = register_key_flags_in_core(_device.define_device)
define_image = register_key_flags_in_core(_misc.define_image) define_image = register_key_flags_in_core(_misc.define_image)
define_performance = register_key_flags_in_core(_performance.define_performance) define_performance = register_key_flags_in_core(_performance.define_performance)
define_distribution = register_key_flags_in_core(
_distribution.define_distribution)
help_wrap = _conventions.help_wrap help_wrap = _conventions.help_wrap
......
...@@ -23,7 +23,9 @@ from official.utils.flags import core as flags_core # pylint: disable=g-bad-imp ...@@ -23,7 +23,9 @@ from official.utils.flags import core as flags_core # pylint: disable=g-bad-imp
def define_flags(): def define_flags():
flags_core.define_base(num_gpu=False) flags_core.define_base(num_gpu=False)
flags_core.define_performance(dynamic_loss_scale=True, loss_scale=True) flags_core.define_performance(
num_parallel_calls=True, inter_op=True, intra_op=True,
dynamic_loss_scale=True, loss_scale=True)
flags_core.define_image() flags_core.define_image()
flags_core.define_benchmark() flags_core.define_benchmark()
......
...@@ -127,10 +127,7 @@ def get_distribution_strategy(distribution_strategy="default", ...@@ -127,10 +127,7 @@ def get_distribution_strategy(distribution_strategy="default",
return None return None
if distribution_strategy == "tpu": if distribution_strategy == "tpu":
if not tpu_address: # When tpu_address is an empty string, we communicate with local TPUs.
raise ValueError("`tpu_address` must be specified when using "
"TPUStrategy.")
# Initialize TPU System. # Initialize TPU System.
cluster_resolver = tpu_lib.tpu_initialize(tpu_address) cluster_resolver = tpu_lib.tpu_initialize(tpu_address)
return tf.distribute.experimental.TPUStrategy(cluster_resolver) return tf.distribute.experimental.TPUStrategy(cluster_resolver)
...@@ -205,38 +202,64 @@ class SyntheticDataset(object): ...@@ -205,38 +202,64 @@ class SyntheticDataset(object):
"""A dataset that generates synthetic data on each device.""" """A dataset that generates synthetic data on each device."""
def __init__(self, dataset, split_by=1):
  """Captures one element of `dataset` in local variables.

  Args:
    dataset: Dataset from which a single element is taken as the synthetic
      payload.
    split_by: Number of pieces each tensor is split into along axis 0; only
      the first piece is kept.
  """
  # dataset.take(1) doesn't have GPU kernel.
  with tf.device('device:CPU:0'):
    element = tf.data.experimental.get_single_element(dataset.take(1))
  variables = []
  initializers = []
  for component in tf.nest.flatten(element):
    shard = tf.split(component, num_or_size_splits=split_by, axis=0)[0]
    assert shard.shape.is_fully_defined(), shard.shape
    # Each variable gets a random name (see `_random_name`).
    variable = tf.compat.v1.get_local_variable(
        self._random_name(), initializer=shard)
    variables.append(variable)
    initializers.append(variable.initializer)
  # Restore the original nested structure around the new variables.
  packed = tf.nest.pack_sequence_as(element, variables)
  self._iterator = SyntheticIterator(packed, initializers)
def _random_name(self, size=10, chars=string.ascii_uppercase + string.digits):
  """Returns a string of `size` characters picked at random from `chars`."""
  picked = []
  for _ in range(size):
    picked.append(random.choice(chars))
  return ''.join(picked)
def __iter__(self):
  """Returns the iterator stored on this dataset as `self._iterator`."""
  return self._iterator
def make_one_shot_iterator(self):
  """Returns the stored `self._iterator`; no new iterator is created."""
  return self._iterator
def make_initializable_iterator(self):
  """Returns the stored `self._iterator`; no new iterator is created."""
  return self._iterator
class SyntheticIterator(object):
  """An iterator that returns the same synthetic data on every step."""

  def __init__(self, input_data, initializers):
    """Stores the data to return and the initializers to expose.

    Args:
      input_data: The value returned by every `get_next()` call.
      initializers: The value returned by `initialize()` outside eager mode.
    """
    self._input_data = input_data
    self._initializers = initializers

  def __iter__(self):
    """Returns self so the object works with `iter()` and `for` loops."""
    return self

  def get_next(self):
    """Returns the stored input data."""
    return self._input_data

  def next(self):
    """Python 2 spelling of `__next__`."""
    return self.__next__()

  def __next__(self):
    """Returns the next element, mapping OutOfRangeError to StopIteration."""
    try:
      return self.get_next()
    except tf.errors.OutOfRangeError:
      raise StopIteration

  def initialize(self):
    """Returns a no-op when executing eagerly, else the stored initializers."""
    if tf.executing_eagerly():
      return tf.no_op()
    else:
      return self._initializers
def random_name(self, size=10, chars=string.ascii_uppercase + string.digits):
return ''.join(random.choice(chars) for _ in range(size))
def _monkey_patch_dataset_method(strategy): def _monkey_patch_dataset_method(strategy):
"""Monkey-patch `strategy`'s `make_dataset_iterator` method.""" """Monkey-patch `strategy`'s `make_dataset_iterator` method."""
def make_dataset_iterator(self, dataset): def make_dataset(self, dataset):
tf.compat.v1.logging.info('Using pure synthetic data.') tf.compat.v1.logging.info('Using pure synthetic data.')
with self.scope(): with self.scope():
if self.extended._global_batch_size: # pylint: disable=protected-access if self.extended._global_batch_size: # pylint: disable=protected-access
...@@ -244,22 +267,34 @@ def _monkey_patch_dataset_method(strategy): ...@@ -244,22 +267,34 @@ def _monkey_patch_dataset_method(strategy):
else: else:
return SyntheticDataset(dataset) return SyntheticDataset(dataset)
strategy.org_make_dataset_iterator = strategy.make_dataset_iterator def make_iterator(self, dataset):
strategy.make_dataset_iterator = make_dataset_iterator dist_dataset = make_dataset(self, dataset)
return iter(dist_dataset)
strategy.orig_make_dataset_iterator = strategy.make_dataset_iterator
strategy.make_dataset_iterator = make_iterator
strategy.orig_distribute_dataset = strategy.experimental_distribute_dataset
strategy.experimental_distribute_dataset = make_dataset
def _undo_monkey_patch_dataset_method(strategy):
  """Restores the dataset methods saved by `_monkey_patch_dataset_method`.

  Args:
    strategy: The strategy class whose methods were patched. A method is only
      restored when its saved `orig_*` attribute is present.
  """
  if hasattr(strategy, 'orig_make_dataset_iterator'):
    strategy.make_dataset_iterator = strategy.orig_make_dataset_iterator
  if hasattr(strategy, 'orig_distribute_dataset'):
    # Restore onto the method the original was saved from, not onto
    # make_dataset_iterator.
    strategy.experimental_distribute_dataset = strategy.orig_distribute_dataset
def set_up_synthetic_data():
  """Monkey-patches the strategy classes to use synthetic data."""
  patch = _monkey_patch_dataset_method
  patch(tf.distribute.OneDeviceStrategy)
  patch(tf.distribute.MirroredStrategy)
  patch(tf.distribute.experimental.MultiWorkerMirroredStrategy)
  # TODO(tobyboyd): Remove when contrib.distribute is all in core.
  if not hasattr(tf, 'contrib'):
    print('Contrib missing: Skip monkey patch tf.contrib.distribute.*')
    return
  contrib_distribute = tf.contrib.distribute
  patch(contrib_distribute.MirroredStrategy)
  patch(contrib_distribute.OneDeviceStrategy)
  patch(contrib_distribute.CollectiveAllReduceStrategy)
...@@ -267,10 +302,14 @@ def set_up_synthetic_data(): ...@@ -267,10 +302,14 @@ def set_up_synthetic_data():
def undo_set_up_synthetic_data():
  """Reverts the synthetic-data monkey patches on the strategy classes."""
  restore = _undo_monkey_patch_dataset_method
  restore(tf.distribute.OneDeviceStrategy)
  restore(tf.distribute.MirroredStrategy)
  restore(tf.distribute.experimental.MultiWorkerMirroredStrategy)
  # TODO(tobyboyd): Remove when contrib.distribute is all in core.
  if not hasattr(tf, 'contrib'):
    print('Contrib missing: Skip remove monkey patch tf.contrib.distribute.*')
    return
  contrib_distribute = tf.contrib.distribute
  restore(contrib_distribute.MirroredStrategy)
  restore(contrib_distribute.OneDeviceStrategy)
  restore(contrib_distribute.CollectiveAllReduceStrategy)
......
...@@ -29,7 +29,7 @@ from absl import flags ...@@ -29,7 +29,7 @@ from absl import flags
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1): def run_synthetic(main, tmp_root, extra_flags=None, synth=True):
"""Performs a minimal run of a model. """Performs a minimal run of a model.
This function is intended to test for syntax errors throughout a model. A This function is intended to test for syntax errors throughout a model. A
...@@ -41,7 +41,6 @@ def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1): ...@@ -41,7 +41,6 @@ def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1):
tmp_root: Root path for the temp directory created by the test class. tmp_root: Root path for the temp directory created by the test class.
extra_flags: Additional flags passed by the caller of this function. extra_flags: Additional flags passed by the caller of this function.
synth: Use synthetic data. synth: Use synthetic data.
max_train: Maximum number of allowed training steps.
""" """
extra_flags = [] if extra_flags is None else extra_flags extra_flags = [] if extra_flags is None else extra_flags
...@@ -54,9 +53,6 @@ def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1): ...@@ -54,9 +53,6 @@ def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1):
if synth: if synth:
args.append("--use_synthetic_data") args.append("--use_synthetic_data")
if max_train is not None:
args.extend(["--max_train_steps", str(max_train)])
try: try:
flags_core.parse_flags(argv=args) flags_core.parse_flags(argv=args)
main(flags.FLAGS) main(flags.FLAGS)
......
[MESSAGES CONTROL] [MESSAGES CONTROL]
disable=R,W,bad-option-value,trailing-newlines disable=R,W,bad-option-value,trailing-newlines,no-name-in-module
[REPORTS] [REPORTS]
# Tells whether to display a full report or only the messages # Tells whether to display a full report or only the messages
......
This folder contains the Keras implementation of the ResNet models. For more This folder contains the Keras implementation of the ResNet models. For more
information about the models, please refer to this [README file](../README.md). information about the models, please refer to this [README file](../../README.md).
Similar to the [estimator implementation](/official/resnet), the Keras Similar to the [estimator implementation](../../r1/resnet), the Keras
implementation has code for both CIFAR-10 data and ImageNet data. The CIFAR-10 implementation has code for both CIFAR-10 data and ImageNet data. The CIFAR-10
version uses a ResNet56 model implemented in version uses a ResNet56 model implemented in
[`resnet_cifar_model.py`](./resnet_cifar_model.py), and the ImageNet version [`resnet_cifar_model.py`](./resnet_cifar_model.py), and the ImageNet version
uses a ResNet50 model implemented in [`resnet_model.py`](./resnet_model.py). uses a ResNet50 model implemented in [`resnet_model.py`](./resnet_model.py).
To use To use
either dataset, make sure that you have the latest version of TensorFlow either dataset, make sure that you have the latest version of TensorFlow
installed and installed and
[add the models folder to your Python path](/official/#running-the-models), [add the models folder to your Python path](/official/#running-the-models),
otherwise you may encounter an error like `ImportError: No module named otherwise you may encounter an error like `ImportError: No module named
official.resnet`. official.resnet`.
## CIFAR-10 ## CIFAR-10
Download and extract the CIFAR-10 data. You can use the following script: Download and extract the CIFAR-10 data. You can use the following script:
```bash ```bash
python cifar10_download_and_extract.py python ../../r1/resnet/cifar10_download_and_extract.py
``` ```
After you download the data, you can run the program by: After you download the data, you can run the program by:
```bash ```bash
python keras_cifar_main.py python resnet_cifar_main.py
``` ```
If you did not use the default directory to download the data, specify the If you did not use the default directory to download the data, specify the
location with the `--data_dir` flag, like: location with the `--data_dir` flag, like:
```bash ```bash
python keras_cifar_main.py --data_dir=/path/to/cifar python resnet_cifar_main.py --data_dir=/path/to/cifar
``` ```
## ImageNet ## ImageNet
Download the ImageNet dataset and convert it to TFRecord format. Download the ImageNet dataset and convert it to TFRecord format.
The following [script](https://github.com/tensorflow/tpu/blob/master/tools/datasets/imagenet_to_gcs.py) The following [script](https://github.com/tensorflow/tpu/blob/master/tools/datasets/imagenet_to_gcs.py)
and [README](https://github.com/tensorflow/tpu/tree/master/tools/datasets#imagenet_to_gcspy) and [README](https://github.com/tensorflow/tpu/tree/master/tools/datasets#imagenet_to_gcspy)
provide a few options. provide a few options.
...@@ -44,57 +44,57 @@ provide a few options. ...@@ -44,57 +44,57 @@ provide a few options.
Once your dataset is ready, you can begin training the model as follows: Once your dataset is ready, you can begin training the model as follows:
```bash ```bash
python keras_imagenet_main.py python resnet_imagenet_main.py
``` ```
Again, if you did not download the data to the default directory, specify the Again, if you did not download the data to the default directory, specify the
location with the `--data_dir` flag: location with the `--data_dir` flag:
```bash ```bash
python keras_imagenet_main.py --data_dir=/path/to/imagenet python resnet_imagenet_main.py --data_dir=/path/to/imagenet
``` ```
There are more flag options you can specify. Here are some examples: There are more flag options you can specify. Here are some examples:
- `--use_synthetic_data`: when set to true, synthetic data, rather than real - `--use_synthetic_data`: when set to true, synthetic data, rather than real
data, are used; data, are used;
- `--batch_size`: the batch size used for the model; - `--batch_size`: the batch size used for the model;
- `--model_dir`: the directory to save the model checkpoint; - `--model_dir`: the directory to save the model checkpoint;
- `--train_epochs`: number of epochs to run for training the model; - `--train_epochs`: number of epochs to run for training the model;
- `--train_steps`: number of steps to run for training the model. We now only - `--train_steps`: number of steps to run for training the model. We now only
support a number that is smaller than the number of batches in an epoch. support a number that is smaller than the number of batches in an epoch.
- `--skip_eval`: when set to true, evaluation as well as validation during - `--skip_eval`: when set to true, evaluation as well as validation during
training is skipped training is skipped
For example, this is a typical command line to run with ImageNet data with For example, this is a typical command line to run with ImageNet data with
batch size 128 per GPU: batch size 128 per GPU:
```bash ```bash
python -m keras_imagenet_main \ python -m resnet_imagenet_main \
--model_dir=/tmp/model_dir/something \ --model_dir=/tmp/model_dir/something \
--num_gpus=2 \ --num_gpus=2 \
--batch_size=128 \ --batch_size=128 \
--train_epochs=90 \ --train_epochs=90 \
--train_steps=10 \ --train_steps=10 \
--use_synthetic_data=false --use_synthetic_data=false
``` ```
See [`keras_common.py`](keras_common.py) for full list of options. See [`common.py`](common.py) for full list of options.
## Using multiple GPUs ## Using multiple GPUs
You can train these models on multiple GPUs using `tf.distribute.Strategy` API. You can train these models on multiple GPUs using `tf.distribute.Strategy` API.
You can read more about them in this You can read more about them in this
[guide](https://www.tensorflow.org/guide/distribute_strategy). [guide](https://www.tensorflow.org/guide/distribute_strategy).
In this example, we have made it easy to use with just a command line flag In this example, we have made it easy to use with just a command line flag
`--num_gpus`. By default this flag is 1 if TensorFlow is compiled with CUDA, `--num_gpus`. By default this flag is 1 if TensorFlow is compiled with CUDA,
and 0 otherwise. and 0 otherwise.
- --num_gpus=0: Uses tf.distribute.OneDeviceStrategy with CPU as the device. - --num_gpus=0: Uses tf.distribute.OneDeviceStrategy with CPU as the device.
- --num_gpus=1: Uses tf.distribute.OneDeviceStrategy with GPU as the device. - --num_gpus=1: Uses tf.distribute.OneDeviceStrategy with GPU as the device.
- --num_gpus=2+: Uses tf.distribute.MirroredStrategy to run synchronous - --num_gpus=2+: Uses tf.distribute.MirroredStrategy to run synchronous
distributed training across the GPUs. distributed training across the GPUs.
If you wish to run without `tf.distribute.Strategy`, you can do so by setting If you wish to run without `tf.distribute.Strategy`, you can do so by setting
`--distribution_strategy=off`. `--distribution_strategy=off`.
...@@ -22,7 +22,7 @@ import os ...@@ -22,7 +22,7 @@ import os
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from official.resnet.keras import imagenet_preprocessing from official.vision.image_classification import imagenet_preprocessing
HEIGHT = 32 HEIGHT = 32
WIDTH = 32 WIDTH = 32
......
...@@ -20,17 +20,13 @@ from __future__ import print_function ...@@ -20,17 +20,13 @@ from __future__ import print_function
import multiprocessing import multiprocessing
import os import os
import numpy as np
# pylint: disable=g-bad-import-order
from absl import flags from absl import flags
import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
# pylint: disable=ungrouped-imports
from tensorflow.python.keras.optimizer_v2 import (gradient_descent as
gradient_descent_v2)
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
BASE_LEARNING_RATE = 0.1 # This matches Jing's version. BASE_LEARNING_RATE = 0.1 # This matches Jing's version.
...@@ -262,6 +258,7 @@ def define_keras_flags(dynamic_loss_scale=True): ...@@ -262,6 +258,7 @@ def define_keras_flags(dynamic_loss_scale=True):
force_v2_in_keras_compile=True) force_v2_in_keras_compile=True)
flags_core.define_image() flags_core.define_image()
flags_core.define_benchmark() flags_core.define_benchmark()
flags_core.define_distribution()
flags.adopt_module_key_flags(flags_core) flags.adopt_module_key_flags(flags_core)
flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?') flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
......
...@@ -12,21 +12,21 @@ ...@@ -12,21 +12,21 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Tests for the keras_common module.""" """Tests for the common module."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import print_function from __future__ import print_function
from mock import Mock from mock import Mock
import numpy as np import numpy as np
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf
from tensorflow.python.platform import googletest
from official.resnet.keras import keras_common from tensorflow.python.platform import googletest
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
from official.vision.image_classification import common
class KerasCommonTests(tf.test.TestCase): class KerasCommonTests(tf.test.TestCase):
"""Tests for keras_common.""" """Tests for common."""
@classmethod @classmethod
def setUpClass(cls): # pylint: disable=invalid-name def setUpClass(cls): # pylint: disable=invalid-name
...@@ -42,7 +42,7 @@ class KerasCommonTests(tf.test.TestCase): ...@@ -42,7 +42,7 @@ class KerasCommonTests(tf.test.TestCase):
keras_utils.BatchTimestamp(1, 2), keras_utils.BatchTimestamp(1, 2),
keras_utils.BatchTimestamp(2, 3)] keras_utils.BatchTimestamp(2, 3)]
th.train_finish_time = 12345 th.train_finish_time = 12345
stats = keras_common.build_stats(history, eval_output, [th]) stats = common.build_stats(history, eval_output, [th])
self.assertEqual(1.145, stats['loss']) self.assertEqual(1.145, stats['loss'])
self.assertEqual(.99988, stats['training_accuracy_top_1']) self.assertEqual(.99988, stats['training_accuracy_top_1'])
...@@ -57,7 +57,7 @@ class KerasCommonTests(tf.test.TestCase): ...@@ -57,7 +57,7 @@ class KerasCommonTests(tf.test.TestCase):
history = self._build_history(1.145, cat_accuracy_sparse=.99988) history = self._build_history(1.145, cat_accuracy_sparse=.99988)
eval_output = self._build_eval_output(.928, 1.9844) eval_output = self._build_eval_output(.928, 1.9844)
stats = keras_common.build_stats(history, eval_output, None) stats = common.build_stats(history, eval_output, None)
self.assertEqual(1.145, stats['loss']) self.assertEqual(1.145, stats['loss'])
self.assertEqual(.99988, stats['training_accuracy_top_1']) self.assertEqual(.99988, stats['training_accuracy_top_1'])
......
...@@ -22,13 +22,13 @@ from absl import app as absl_app ...@@ -22,13 +22,13 @@ from absl import app as absl_app
from absl import flags from absl import flags
import tensorflow as tf import tensorflow as tf
from official.resnet.keras import cifar_preprocessing
from official.resnet.keras import keras_common
from official.resnet.keras import resnet_cifar_model
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
from official.utils.logs import logger from official.utils.logs import logger
from official.utils.misc import distribution_utils from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
from official.vision.image_classification import cifar_preprocessing
from official.vision.image_classification import common
from official.vision.image_classification import resnet_cifar_model
LR_SCHEDULE = [ # (multiplier, epoch to start) tuples LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
...@@ -55,7 +55,7 @@ def learning_rate_schedule(current_epoch, ...@@ -55,7 +55,7 @@ def learning_rate_schedule(current_epoch,
Adjusted learning rate. Adjusted learning rate.
""" """
del current_batch, batches_per_epoch # not used del current_batch, batches_per_epoch # not used
initial_learning_rate = keras_common.BASE_LEARNING_RATE * batch_size / 128 initial_learning_rate = common.BASE_LEARNING_RATE * batch_size / 128
learning_rate = initial_learning_rate learning_rate = initial_learning_rate
for mult, start_epoch in LR_SCHEDULE: for mult, start_epoch in LR_SCHEDULE:
if current_epoch >= start_epoch: if current_epoch >= start_epoch:
...@@ -83,8 +83,8 @@ def run(flags_obj): ...@@ -83,8 +83,8 @@ def run(flags_obj):
# Execute flag override logic for better model performance # Execute flag override logic for better model performance
if flags_obj.tf_gpu_thread_mode: if flags_obj.tf_gpu_thread_mode:
keras_common.set_gpu_thread_mode_and_count(flags_obj) common.set_gpu_thread_mode_and_count(flags_obj)
keras_common.set_cudnn_batchnorm_mode() common.set_cudnn_batchnorm_mode()
dtype = flags_core.get_tf_dtype(flags_obj) dtype = flags_core.get_tf_dtype(flags_obj)
if dtype == 'fp16': if dtype == 'fp16':
...@@ -116,7 +116,7 @@ def run(flags_obj): ...@@ -116,7 +116,7 @@ def run(flags_obj):
if flags_obj.use_synthetic_data: if flags_obj.use_synthetic_data:
distribution_utils.set_up_synthetic_data() distribution_utils.set_up_synthetic_data()
input_fn = keras_common.get_synth_input_fn( input_fn = common.get_synth_input_fn(
height=cifar_preprocessing.HEIGHT, height=cifar_preprocessing.HEIGHT,
width=cifar_preprocessing.WIDTH, width=cifar_preprocessing.WIDTH,
num_channels=cifar_preprocessing.NUM_CHANNELS, num_channels=cifar_preprocessing.NUM_CHANNELS,
...@@ -150,7 +150,7 @@ def run(flags_obj): ...@@ -150,7 +150,7 @@ def run(flags_obj):
parse_record_fn=cifar_preprocessing.parse_record) parse_record_fn=cifar_preprocessing.parse_record)
with strategy_scope: with strategy_scope:
optimizer = keras_common.get_optimizer() optimizer = common.get_optimizer()
model = resnet_cifar_model.resnet56(classes=cifar_preprocessing.NUM_CLASSES) model = resnet_cifar_model.resnet56(classes=cifar_preprocessing.NUM_CLASSES)
# TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
...@@ -171,7 +171,7 @@ def run(flags_obj): ...@@ -171,7 +171,7 @@ def run(flags_obj):
if flags_obj.report_accuracy_metrics else None), if flags_obj.report_accuracy_metrics else None),
run_eagerly=flags_obj.run_eagerly) run_eagerly=flags_obj.run_eagerly)
callbacks = keras_common.get_callbacks( callbacks = common.get_callbacks(
learning_rate_schedule, cifar_preprocessing.NUM_IMAGES['train']) learning_rate_schedule, cifar_preprocessing.NUM_IMAGES['train'])
train_steps = cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size train_steps = cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size
...@@ -216,12 +216,12 @@ def run(flags_obj): ...@@ -216,12 +216,12 @@ def run(flags_obj):
if not strategy and flags_obj.explicit_gpu_placement: if not strategy and flags_obj.explicit_gpu_placement:
no_dist_strat_device.__exit__() no_dist_strat_device.__exit__()
stats = keras_common.build_stats(history, eval_output, callbacks) stats = common.build_stats(history, eval_output, callbacks)
return stats return stats
def define_cifar_flags(): def define_cifar_flags():
keras_common.define_keras_flags(dynamic_loss_scale=False) common.define_keras_flags(dynamic_loss_scale=False)
flags_core.set_defaults(data_dir='/tmp/cifar10_data/cifar-10-batches-bin', flags_core.set_defaults(data_dir='/tmp/cifar10_data/cifar-10-batches-bin',
model_dir='/tmp/cifar10_model', model_dir='/tmp/cifar10_model',
......
...@@ -18,17 +18,16 @@ from __future__ import absolute_import ...@@ -18,17 +18,16 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tempfile import mkdtemp import tempfile
import tensorflow as tf import tensorflow as tf
from official.resnet.keras import cifar_preprocessing
from official.resnet.keras import keras_cifar_main
from official.resnet.keras import keras_common
from official.utils.misc import keras_utils
from official.utils.testing import integration
# pylint: disable=ungrouped-imports
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.platform import googletest from tensorflow.python.platform import googletest
from official.utils.misc import keras_utils
from official.utils.testing import integration
from official.vision.image_classification import cifar_preprocessing
from official.vision.image_classification import resnet_cifar_main
class KerasCifarTest(googletest.TestCase): class KerasCifarTest(googletest.TestCase):
...@@ -43,13 +42,13 @@ class KerasCifarTest(googletest.TestCase): ...@@ -43,13 +42,13 @@ class KerasCifarTest(googletest.TestCase):
def get_temp_dir(self): def get_temp_dir(self):
if not self._tempdir: if not self._tempdir:
self._tempdir = mkdtemp(dir=googletest.GetTempDir()) self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
return self._tempdir return self._tempdir
@classmethod @classmethod
def setUpClass(cls): # pylint: disable=invalid-name def setUpClass(cls): # pylint: disable=invalid-name
super(KerasCifarTest, cls).setUpClass() super(KerasCifarTest, cls).setUpClass()
keras_cifar_main.define_cifar_flags() resnet_cifar_main.define_cifar_flags()
def setUp(self): def setUp(self):
super(KerasCifarTest, self).setUp() super(KerasCifarTest, self).setUp()
...@@ -72,7 +71,7 @@ class KerasCifarTest(googletest.TestCase): ...@@ -72,7 +71,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_cifar_main.run, main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
...@@ -88,7 +87,7 @@ class KerasCifarTest(googletest.TestCase): ...@@ -88,7 +87,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_cifar_main.run, main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
...@@ -112,7 +111,7 @@ class KerasCifarTest(googletest.TestCase): ...@@ -112,7 +111,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_cifar_main.run, main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
...@@ -134,7 +133,7 @@ class KerasCifarTest(googletest.TestCase): ...@@ -134,7 +133,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_cifar_main.run, main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
...@@ -157,7 +156,7 @@ class KerasCifarTest(googletest.TestCase): ...@@ -157,7 +156,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_cifar_main.run, main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
...@@ -178,7 +177,7 @@ class KerasCifarTest(googletest.TestCase): ...@@ -178,7 +177,7 @@ class KerasCifarTest(googletest.TestCase):
extra_flags = extra_flags + self._extra_flags extra_flags = extra_flags + self._extra_flags
integration.run_synthetic( integration.run_synthetic(
main=keras_cifar_main.run, main=resnet_cifar_main.run,
tmp_root=self.get_temp_dir(), tmp_root=self.get_temp_dir(),
extra_flags=extra_flags extra_flags=extra_flags
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment