Commit 901c4cc4 authored by Vinh Nguyen

Merge remote-tracking branch 'upstream/master' into amp_resnet50

parents ef30de93 824ff2d6
@@ -42,14 +42,14 @@ from official.utils.logs import mlperf_helper
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.utils.misc import model_helpers
from official.utils.misc import tpu_lib
FLAGS = flags.FLAGS
def metric_fn(logits, dup_mask, params):
dup_mask = tf.cast(dup_mask, tf.float32)
- logits = tf.slice(logits, [0, 0, 1], [-1, -1, -1])
logits = tf.slice(logits, [0, 1], [-1, -1])
in_top_k, _, metric_weights, _ = neumf_model.compute_top_k_and_ndcg(
logits,
dup_mask,
@@ -64,12 +64,38 @@ class MetricLayer(tf.keras.layers.Layer):
def __init__(self, params):
super(MetricLayer, self).__init__()
self.params = params
- self.metric = tf.keras.metrics.Mean(name=rconst.HR_METRIC_NAME)
- def call(self, inputs):
def call(self, inputs, training=False):
logits, dup_mask = inputs
- in_top_k, metric_weights = metric_fn(logits, dup_mask, self.params)
- self.add_metric(self.metric(in_top_k, sample_weight=metric_weights))
if training:
hr_sum = 0.0
hr_count = 0.0
else:
metric, metric_weights = metric_fn(logits, dup_mask, self.params)
hr_sum = tf.reduce_sum(metric * metric_weights)
hr_count = tf.reduce_sum(metric_weights)
self.add_metric(hr_sum, name="hr_sum", aggregation="mean")
self.add_metric(hr_count, name="hr_count", aggregation="mean")
return logits
class LossLayer(tf.keras.layers.Layer):
"""Pass-through loss layer for NCF model."""
def __init__(self, loss_normalization_factor):
super(LossLayer, self).__init__()
self.loss_normalization_factor = loss_normalization_factor
self.loss = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction="sum")
def call(self, inputs):
logits, labels, valid_pt_mask_input = inputs
loss = self.loss(
y_true=labels, y_pred=logits, sample_weight=valid_pt_mask_input)
loss = loss * (1.0 / self.loss_normalization_factor)
self.add_loss(loss)
return logits
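For orientation (not part of the diff): the two aggregated metrics registered by MetricLayer are only turned back into a hit rate after evaluation. A minimal sketch of that arithmetic, assuming the [loss, hr_sum, hr_count] ordering that keras_model.evaluate() produces for this model (as noted later in run_ncf):

# Illustrative only: recombine the "hr_sum"/"hr_count" metrics emitted above.
def hit_rate(eval_loss_and_metrics):
  # evaluate() returns [loss, hr_sum, hr_count] because the two add_metric()
  # calls register the metrics in that order.
  _, hr_sum, hr_count = eval_loss_and_metrics
  return hr_sum / hr_count

print(hit_rate([0.31, 512.0, 1024.0]))  # -> 0.5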
@@ -122,48 +148,24 @@ def _get_keras_model(params):
"""Constructs and returns the model."""
batch_size = params["batch_size"]
- # The input layers are of shape (1, batch_size), to match the size of the
- # input data. The first dimension is needed because the input data are
- # required to be batched to use distribution strategies, and in this case, it
- # is designed to be of batch_size 1 for each replica.
user_input = tf.keras.layers.Input(
- shape=(batch_size,),
- batch_size=params["batches_per_step"],
- name=movielens.USER_COLUMN,
- dtype=tf.int32)
shape=(1,), name=movielens.USER_COLUMN, dtype=tf.int32)
item_input = tf.keras.layers.Input(
- shape=(batch_size,),
- batch_size=params["batches_per_step"],
- name=movielens.ITEM_COLUMN,
- dtype=tf.int32)
shape=(1,), name=movielens.ITEM_COLUMN, dtype=tf.int32)
valid_pt_mask_input = tf.keras.layers.Input(
- shape=(batch_size,),
- batch_size=params["batches_per_step"],
- name=rconst.VALID_POINT_MASK,
- dtype=tf.bool)
shape=(1,), name=rconst.VALID_POINT_MASK, dtype=tf.bool)
dup_mask_input = tf.keras.layers.Input(
- shape=(batch_size,),
- batch_size=params["batches_per_step"],
- name=rconst.DUPLICATE_MASK,
- dtype=tf.int32)
shape=(1,), name=rconst.DUPLICATE_MASK, dtype=tf.int32)
label_input = tf.keras.layers.Input(
- shape=(batch_size, 1),
- batch_size=params["batches_per_step"],
- name=rconst.TRAIN_LABEL_KEY,
- dtype=tf.bool)
shape=(1,), name=rconst.TRAIN_LABEL_KEY, dtype=tf.bool)
- base_model = neumf_model.construct_model(
- user_input, item_input, params, need_strip=True)
- base_model_output = base_model.output
base_model = neumf_model.construct_model(user_input, item_input, params)
- logits = tf.keras.layers.Lambda(
- lambda x: tf.expand_dims(x, 0),
- name="logits")(base_model_output)
logits = base_model.output
zeros = tf.keras.layers.Lambda(
lambda x: x * 0)(logits)
@@ -172,9 +174,14 @@ def _get_keras_model(params):
[zeros, logits],
axis=-1)
- """CTL does metric calculation as part of eval_step function"""
# Custom training loop calculates loss and metric as a part of
# training/evaluation step function.
if not params["keras_use_ctl"]:
softmax_logits = MetricLayer(params)([softmax_logits, dup_mask_input])
# TODO(b/134744680): Use model.add_loss() instead once the API is well
# supported.
softmax_logits = LossLayer(batch_size)(
[softmax_logits, label_input, valid_pt_mask_input])
keras_model = tf.keras.Model(
inputs={
@@ -185,15 +192,6 @@ def _get_keras_model(params):
rconst.TRAIN_LABEL_KEY: label_input},
outputs=softmax_logits)
- loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(
- from_logits=True,
- reduction="sum")
- keras_model.add_loss(loss_obj(
- y_true=label_input,
- y_pred=softmax_logits,
- sample_weight=valid_pt_mask_input) * 1.0 / batch_size)
keras_model.summary()
return keras_model
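As a side note, a toy sketch of the dict-keyed functional wiring that _get_keras_model now uses. The embedding sizes, the Dense head, and the "user_id"/"item_id" names are invented for illustration; only the shape-(1,) integer inputs and the dict-of-inputs pattern mirror the code above:

import tensorflow as tf

user_input = tf.keras.layers.Input(shape=(1,), name="user_id", dtype=tf.int32)
item_input = tf.keras.layers.Input(shape=(1,), name="item_id", dtype=tf.int32)
user_emb = tf.keras.layers.Embedding(100, 8)(user_input)   # (batch, 1, 8)
item_emb = tf.keras.layers.Embedding(50, 8)(item_input)    # (batch, 1, 8)
x = tf.keras.layers.Concatenate(axis=-1)([user_emb, item_emb])
logits = tf.keras.layers.Dense(1)(x)                        # (batch, 1, 1)
toy_model = tf.keras.Model(
    inputs={"user_id": user_input, "item_id": item_input}, outputs=logits)
toy_model({"user_id": tf.constant([[1]]), "item_id": tf.constant([[2]])})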
@@ -207,39 +205,28 @@ def run_ncf(_):
print("Setting tf seed")
tf.random.set_seed(FLAGS.seed)
- # TODO(seemuch): Support different train and eval batch sizes
- if FLAGS.eval_batch_size != FLAGS.batch_size:
- logging.warning(
- "The Keras implementation of NCF currently does not support batch_size "
- "!= eval_batch_size ({} vs. {}). Overriding eval_batch_size to match "
- "batch_size".format(FLAGS.eval_batch_size, FLAGS.batch_size)
- )
- FLAGS.eval_batch_size = FLAGS.batch_size
params = ncf_common.parse_flags(FLAGS)
model_helpers.apply_clean(flags.FLAGS)
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy,
- num_gpus=FLAGS.num_gpus)
num_gpus=FLAGS.num_gpus,
tpu_address=FLAGS.tpu)
params["distribute_strategy"] = strategy
if not keras_utils.is_v2_0() and strategy is not None:
logging.error("NCF Keras only works with distribution strategy in TF 2.0")
return
if (params["keras_use_ctl"] and (
not keras_utils.is_v2_0() or strategy is None)):
logging.error(
"Custom training loop only works with tensorflow 2.0 and dist strat.")
return
if params["use_tpu"] and not params["keras_use_ctl"]:
logging.error("Custom training loop must be used when using TPUStrategy.")
return
- # ncf_common rounds eval_batch_size (this is needed due to a reshape during
- # eval). This carries over that rounding to batch_size as well. This is the
- # per device batch size
- params["batch_size"] = params["eval_batch_size"]
batch_size = params["batch_size"]
time_callback = keras_utils.TimeHistory(batch_size, FLAGS.log_steps)
callbacks = [time_callback]
@@ -248,8 +235,7 @@ def run_ncf(_):
if generate_input_online:
# Start data producing thread.
- num_users, num_items, num_train_steps, num_eval_steps, producer = (
- ncf_common.get_inputs(params))
num_users, num_items, _, _, producer = ncf_common.get_inputs(params)
producer.start()
per_epoch_callback = IncrementEpochCallback(producer)
callbacks.append(per_epoch_callback)
@@ -261,15 +247,19 @@ def run_ncf(_):
num_items = input_meta_data["num_items"]
params["num_users"], params["num_items"] = num_users, num_items
- (train_input_dataset, eval_input_dataset, num_train_steps, num_eval_steps) = \
- (ncf_input_pipeline.create_ncf_input_data(
- params, producer, input_meta_data))
- steps_per_epoch = None if generate_input_online else num_train_steps
if FLAGS.early_stopping:
early_stopping_callback = CustomEarlyStopping(
"val_HR_METRIC", desired_value=FLAGS.hr_threshold)
callbacks.append(early_stopping_callback)
with tf.device(tpu_lib.get_primary_cpu_task(params["use_tpu"])):
(train_input_dataset, eval_input_dataset,
num_train_steps, num_eval_steps) = \
(ncf_input_pipeline.create_ncf_input_data(
params, producer, input_meta_data, strategy))
steps_per_epoch = None if generate_input_online else num_train_steps
with distribution_utils.get_strategy_scope(strategy):
keras_model = _get_keras_model(params)
optimizer = tf.keras.optimizers.Adam(
@@ -279,23 +269,108 @@ def run_ncf(_):
epsilon=params["epsilon"])
if params["keras_use_ctl"]:
train_loss, eval_results = run_ncf_custom_training(
params,
strategy,
keras_model,
optimizer,
callbacks,
train_input_dataset,
eval_input_dataset,
num_train_steps,
num_eval_steps,
generate_input_online=generate_input_online)
else:
# TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
# a valid arg for this model. Also remove as a valid flag.
if FLAGS.force_v2_in_keras_compile is not None:
keras_model.compile(
optimizer=optimizer,
run_eagerly=FLAGS.run_eagerly,
experimental_run_tf_function=FLAGS.force_v2_in_keras_compile)
else:
keras_model.compile(
optimizer=optimizer, run_eagerly=FLAGS.run_eagerly)
history = keras_model.fit(
train_input_dataset,
epochs=FLAGS.train_epochs,
steps_per_epoch=steps_per_epoch,
callbacks=callbacks,
validation_data=eval_input_dataset,
validation_steps=num_eval_steps,
verbose=2)
logging.info("Training done. Start evaluating")
eval_loss_and_metrics = keras_model.evaluate(
eval_input_dataset, steps=num_eval_steps, verbose=2)
logging.info("Keras evaluation is done.")
# Keras evaluate() API returns scalar loss and metric values from
# evaluation as a list. Here, the returned list would contain
# [evaluation loss, hr sum, hr count].
eval_hit_rate = eval_loss_and_metrics[1] / eval_loss_and_metrics[2]
# Format evaluation result into [eval loss, eval hit accuracy].
eval_results = [eval_loss_and_metrics[0], eval_hit_rate]
if history and history.history:
train_history = history.history
train_loss = train_history["loss"][-1]
stats = build_stats(train_loss, eval_results, time_callback)
return stats
def run_ncf_custom_training(params,
strategy,
keras_model,
optimizer,
callbacks,
train_input_dataset,
eval_input_dataset,
num_train_steps,
num_eval_steps,
generate_input_online=True):
"""Runs custom training loop.
Args:
params: Dictionary containing training parameters.
strategy: Distribution strategy to be used for distributed training.
keras_model: Model used for training.
optimizer: Optimizer used for training.
callbacks: Callbacks to be invoked between batches/epochs.
train_input_dataset: tf.data.Dataset used for training.
eval_input_dataset: tf.data.Dataset used for evaluation.
num_train_steps: Total number of steps to run for training.
num_eval_steps: Total number of steps to run for evaluation.
generate_input_online: Whether input data was generated by the data producer.
When data is generated by the data producer, the train dataset must be
re-initialized after every epoch.
Returns:
A tuple of train loss and a list of training and evaluation results.
"""
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
- reduction="sum",
- from_logits=True)
- train_input_iterator = strategy.make_dataset_iterator(train_input_dataset)
- eval_input_iterator = strategy.make_dataset_iterator(eval_input_dataset)
- @tf.function
- def train_step():
reduction="sum", from_logits=True)
train_input_iterator = iter(
strategy.experimental_distribute_dataset(train_input_dataset))
def train_step(train_iterator):
"""Called once per step to train the model."""
def step_fn(features):
"""Computes loss and applied gradient per replica."""
with tf.GradientTape() as tape:
softmax_logits = keras_model(features)
labels = features[rconst.TRAIN_LABEL_KEY]
- loss = loss_object(labels, softmax_logits,
loss = loss_object(
labels,
softmax_logits,
sample_weight=features[rconst.VALID_POINT_MASK])
- loss *= (1.0 / (batch_size*strategy.num_replicas_in_sync))
loss *= (1.0 / params["batch_size"])
grads = tape.gradient(loss, keras_model.trainable_variables)
# Converting gradients to dense form helps in perf on GPU for NCF
@@ -304,33 +379,42 @@ def run_ncf(_):
optimizer.apply_gradients(grads)
return loss
- per_replica_losses = strategy.experimental_run(step_fn,
- train_input_iterator)
per_replica_losses = strategy.experimental_run_v2(
step_fn, args=(next(train_iterator),))
mean_loss = strategy.reduce(
tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
return mean_loss
- @tf.function
- def eval_step():
def eval_step(eval_iterator):
"""Called once per eval step to compute eval metrics."""
def step_fn(features):
"""Computes eval metrics per replica."""
softmax_logits = keras_model(features)
- in_top_k, metric_weights = metric_fn(
- softmax_logits, features[rconst.DUPLICATE_MASK], params)
- hr_sum = tf.reduce_sum(in_top_k*metric_weights)
in_top_k, metric_weights = metric_fn(softmax_logits,
features[rconst.DUPLICATE_MASK],
params)
hr_sum = tf.reduce_sum(in_top_k * metric_weights)
hr_count = tf.reduce_sum(metric_weights)
return hr_sum, hr_count
per_replica_hr_sum, per_replica_hr_count = (
- strategy.experimental_run(step_fn, eval_input_iterator))
strategy.experimental_run_v2(
step_fn, args=(next(eval_iterator),)))
hr_sum = strategy.reduce(
tf.distribute.ReduceOp.SUM, per_replica_hr_sum, axis=None)
hr_count = strategy.reduce(
tf.distribute.ReduceOp.SUM, per_replica_hr_count, axis=None)
return hr_sum, hr_count
- time_callback.on_train_begin()
if not FLAGS.run_eagerly:
train_step = tf.function(train_step)
eval_step = tf.function(eval_step)
for callback in callbacks:
callback.on_train_begin()
train_loss = 0
for epoch in range(FLAGS.train_epochs):
for cb in callbacks:
cb.on_epoch_begin(epoch)
@@ -341,68 +425,43 @@ def run_ncf(_):
# contains all epoch worth of data. Thus we do not need
# to initialize dataset when reading from tf record files.
if generate_input_online:
- train_input_iterator.initialize()
train_input_iterator = iter(
strategy.experimental_distribute_dataset(train_input_dataset))
train_loss = 0
for step in range(num_train_steps):
- time_callback.on_batch_begin(step+epoch*num_train_steps)
- train_loss += train_step()
- time_callback.on_batch_end(step+epoch*num_train_steps)
current_step = step + epoch * num_train_steps
for c in callbacks:
c.on_batch_begin(current_step)
train_loss += train_step(train_input_iterator)
for c in callbacks:
c.on_batch_end(current_step)
train_loss /= num_train_steps
- logging.info("Done training epoch %s, epoch loss=%s.",
- epoch+1, train_loss)
logging.info("Done training epoch %s, epoch loss=%s.", epoch + 1,
train_loss)
- eval_input_iterator.initialize()
eval_input_iterator = iter(
strategy.experimental_distribute_dataset(eval_input_dataset))
hr_sum = 0
hr_count = 0
for _ in range(num_eval_steps):
- step_hr_sum, step_hr_count = eval_step()
step_hr_sum, step_hr_count = eval_step(eval_input_iterator)
hr_sum += step_hr_sum
hr_count += step_hr_count
- logging.info("Done eval epoch %s, hr=%s.", epoch+1, hr_sum/hr_count)
logging.info("Done eval epoch %s, hr=%s.", epoch + 1, hr_sum / hr_count)
if (FLAGS.early_stopping and
- float(hr_sum/hr_count) > params["hr_threshold"]):
float(hr_sum / hr_count) > params["hr_threshold"]):
break
- time_callback.on_train_end()
- eval_results = [None, hr_sum/hr_count]
- else:
- with distribution_utils.get_strategy_scope(strategy):
- # TODO(b/138957587): Remove when force_v2_in_keras_compile is on longer
- # a valid arg for this model. Also remove as a valid flag.
- if FLAGS.force_v2_in_keras_compile is not None:
- keras_model.compile(
- optimizer=optimizer,
- run_eagerly=FLAGS.run_eagerly,
- experimental_run_tf_function=FLAGS.force_v2_in_keras_compile)
- else:
- keras_model.compile(
- optimizer=optimizer,
- run_eagerly=FLAGS.run_eagerly)
- history = keras_model.fit(
- train_input_dataset,
- epochs=FLAGS.train_epochs,
- steps_per_epoch=steps_per_epoch,
- callbacks=callbacks,
- validation_data=eval_input_dataset,
- validation_steps=num_eval_steps,
- verbose=2)
- logging.info("Training done. Start evaluating")
- eval_results = keras_model.evaluate(
- eval_input_dataset, steps=num_eval_steps, verbose=2)
- logging.info("Keras evaluation is done.")
- if history and history.history:
- train_history = history.history
- train_loss = train_history["loss"][-1]
- stats = build_stats(train_loss, eval_results, time_callback)
- return stats
for c in callbacks:
c.on_train_end()
return train_loss, [None, hr_sum / hr_count]
def build_stats(loss, eval_result, time_callback):
@@ -442,8 +501,6 @@ def main(_):
with logger.benchmark_context(FLAGS), \
mlperf_helper.LOGGER(FLAGS.output_ml_perf_compliance_logging):
mlperf_helper.set_ncf_root(os.path.split(os.path.abspath(__file__))[0])
- if FLAGS.tpu:
- raise ValueError("NCF in Keras does not support TPU for now")
run_ncf(FLAGS)
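For readers new to the distribution API that run_ncf_custom_training relies on, here is a self-contained sketch of the same iterate/run/reduce pattern with a dummy step function. The dataset and the "loss" are invented; note that experimental_run_v2 was later renamed to Strategy.run:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.from_tensor_slices(
    tf.random.uniform([64, 4])).batch(16, drop_remainder=True)
dist_iterator = iter(strategy.experimental_distribute_dataset(dataset))

@tf.function
def train_step(iterator):
  def step_fn(features):
    # Stand-in for the real forward/backward pass.
    return tf.reduce_mean(features)
  per_replica = strategy.experimental_run_v2(  # Strategy.run in newer TF
      step_fn, args=(next(iterator),))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)

print(train_step(dist_iterator))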
...
@@ -189,26 +189,26 @@ class NcfTest(tf.test.TestCase):
self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3) +
2 * math.log(2) / math.log(4)) / 4)
- _BASE_END_TO_END_FLAGS = ['-batch_size', '1024', '-train_epochs', '1']
_BASE_END_TO_END_FLAGS = ['-batch_size', '1044', '-train_epochs', '1']
@unittest.skipIf(keras_utils.is_v2_0(), "TODO(b/136018594)")
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_estimator(self):
integration.run_synthetic(
- ncf_estimator_main.main, tmp_root=self.get_temp_dir(), max_train=None,
ncf_estimator_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS)
@unittest.skipIf(keras_utils.is_v2_0(), "TODO(b/136018594)")
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_estimator_mlperf(self):
integration.run_synthetic(
- ncf_estimator_main.main, tmp_root=self.get_temp_dir(), max_train=None,
ncf_estimator_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + ['-ml_perf', 'True'])
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_keras_no_dist_strat(self):
integration.run_synthetic(
- ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None,
ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS +
['-distribution_strategy', 'off'])
@@ -216,7 +216,7 @@ class NcfTest(tf.test.TestCase):
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_end_to_end_keras_dist_strat(self):
integration.run_synthetic(
- ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None,
ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '0'])
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@@ -226,7 +226,7 @@ class NcfTest(tf.test.TestCase):
['-num_gpus', '0'] +
['-keras_use_ctl', 'True'])
integration.run_synthetic(
- ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None,
ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=flags)
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@@ -238,7 +238,7 @@ class NcfTest(tf.test.TestCase):
format(1, context.num_gpus()))
integration.run_synthetic(
- ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None,
ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '1'])
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@@ -250,7 +250,7 @@ class NcfTest(tf.test.TestCase):
format(2, context.num_gpus()))
integration.run_synthetic(
- ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None,
ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '2'])
if __name__ == "__main__":
...
@@ -109,7 +109,6 @@ def neumf_model_fn(features, labels, mode, params):
mlperf_helper.ncf_print(key=mlperf_helper.TAGS.OPT_HP_ADAM_EPSILON,
value=params["epsilon"])
optimizer = tf.compat.v1.train.AdamOptimizer(
learning_rate=params["learning_rate"],
beta1=params["beta1"],
@@ -151,7 +150,7 @@ def _strip_first_and_last_dimension(x, batch_size):
return tf.reshape(x[0, :], (batch_size,))
- def construct_model(user_input, item_input, params, need_strip=False):
def construct_model(user_input, item_input, params):
# type: (tf.Tensor, tf.Tensor, dict) -> tf.keras.Model
"""Initialize NeuMF model.
@@ -184,34 +183,33 @@ def construct_model(user_input, item_input, params):
# Initializer for embedding layers
embedding_initializer = "glorot_uniform"
- if need_strip:
- batch_size = params["batch_size"]
- user_input_reshaped = tf.keras.layers.Lambda(
- lambda x: _strip_first_and_last_dimension(
- x, batch_size))(user_input)
- item_input_reshaped = tf.keras.layers.Lambda(
- lambda x: _strip_first_and_last_dimension(
- x, batch_size))(item_input)
def mf_slice_fn(x):
x = tf.squeeze(x, [1])
return x[:, :mf_dim]
def mlp_slice_fn(x):
x = tf.squeeze(x, [1])
return x[:, mf_dim:]
# It turns out to be significantly more efficient to store the MF and MLP
# embedding portions in the same table, and then slice as needed.
- mf_slice_fn = lambda x: x[:, :mf_dim]
- mlp_slice_fn = lambda x: x[:, mf_dim:]
embedding_user = tf.keras.layers.Embedding(
- num_users, mf_dim + model_layers[0] // 2,
num_users,
mf_dim + model_layers[0] // 2,
embeddings_initializer=embedding_initializer,
embeddings_regularizer=tf.keras.regularizers.l2(mf_regularization),
- input_length=1, name="embedding_user")(
- user_input_reshaped if need_strip else user_input)
input_length=1,
name="embedding_user")(
user_input)
embedding_item = tf.keras.layers.Embedding(
- num_items, mf_dim + model_layers[0] // 2,
num_items,
mf_dim + model_layers[0] // 2,
embeddings_initializer=embedding_initializer,
embeddings_regularizer=tf.keras.regularizers.l2(mf_regularization),
- input_length=1, name="embedding_item")(
- item_input_reshaped if need_strip else item_input)
input_length=1,
name="embedding_item")(
item_input)
# GMF part
mf_user_latent = tf.keras.layers.Lambda(
...
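A small sketch (toy dimensions, not the repo's construct_model) of the shared-table idea that mf_slice_fn and mlp_slice_fn implement: one embedding holds both the GMF and MLP vectors, and each branch slices its columns out after squeezing the length-1 axis.

import tensorflow as tf

num_users, mf_dim, mlp_dim = 1000, 8, 16
table = tf.keras.layers.Embedding(num_users, mf_dim + mlp_dim, input_length=1)
user_ids = tf.constant([[3], [7]])             # shape (batch, 1)
emb = tf.squeeze(table(user_ids), axis=1)      # shape (batch, mf_dim + mlp_dim)
mf_part = emb[:, :mf_dim]                      # GMF slice, shape (batch, 8)
mlp_part = emb[:, mf_dim:]                     # MLP slice, shape (batch, 16)
print(mf_part.shape, mlp_part.shape)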
# ResNet in TensorFlow
* For the Keras version of the ResNet model, see
- [`official/resnet/keras`](keras).
[`official/vision/image_classification`](../vision/image_classification).
* For the Keras custom training loop version, see
[`official/resnet/ctl`](ctl).
* For the Estimator version, see [`official/r1/resnet`](../r1/resnet).
- # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
@@ -283,4 +283,6 @@ if __name__ == '__main__':
logging.set_verbosity(logging.INFO)
keras_common.define_keras_flags()
ctl_common.define_ctl_flags()
flags.adopt_module_key_flags(keras_common)
flags.adopt_module_key_flags(ctl_common)
absl_app.run(main)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bring in the shared Keras ResNet modules into this module.
The TensorFlow official Keras models are moved under
official/vision/image_classification
In order to be backward compatible with models that directly import its modules,
we import the Keras ResNet modules under official.resnet.keras.
New TF models should not depend on modules directly under this path.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from official.vision.image_classification import cifar_preprocessing
from official.vision.image_classification import common as keras_common
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import resnet_cifar_main as keras_cifar_main
from official.vision.image_classification import resnet_cifar_model
from official.vision.image_classification import resnet_imagenet_main as keras_imagenet_main
from official.vision.image_classification import resnet_model
del absolute_import
del division
del print_function
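Usage note (an assumption, since the diff does not show the file path): if this shim lives at official/resnet/keras/__init__.py as its docstring implies, pre-existing import paths keep resolving, for example:

# Old-style import path, now served by the aliases defined above.
from official.resnet.keras import resnet_model
print(resnet_model.__name__)  # official.vision.image_classification.resnet_model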
@@ -208,21 +208,6 @@ class ShakespeareAccuracy(ShakespeareBenchmarkBase):
FLAGS.model_dir = ''
self._run_and_report_benchmark()
- def benchmark_xla_8_gpu(self):
- """Benchmark 8 gpu w/xla.
- This is test is for accuracy not scaling. The batch-size is not scaled to
- the number of gpus.
- """
- self._setup()
- FLAGS.num_gpus = 8
- FLAGS.training_data = self.train_data
- FLAGS.batch_size = 64
- FLAGS.train_epochs = 43
- FLAGS.model_dir = ''
- FLAGS.enable_xla = True
- self._run_and_report_benchmark()
class ShakespeareKerasBenchmarkReal(ShakespeareBenchmarkBase):
"""Benchmark accuracy tests."""
...
@@ -273,7 +273,7 @@ def _generate_synthetic_data(params):
label_value=1,
label_dtype=tf.int64,
)
- return dataset.batch(batch)
return dataset.batch(batch, drop_remainder=True)
def train_input_fn(params):
...
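Why drop_remainder=True matters here (illustrative, not from the diff): dropping the final partial batch gives every element a fully static shape, which TPU execution requires.

import tensorflow as tf

print(tf.data.Dataset.range(10).batch(4).element_spec)
# TensorSpec(shape=(None,), ...)  -- the last partial batch keeps the dim dynamic
print(tf.data.Dataset.range(10).batch(4, drop_remainder=True).element_spec)
# TensorSpec(shape=(4,), ...)     -- static batch dimension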
@@ -176,6 +176,21 @@ def define_transformer_flags():
flags.DEFINE_string(
name='mode', default='train',
help=flags_core.help_wrap('mode: train, eval, or predict'))
flags.DEFINE_bool(
name='use_ctl',
default=False,
help=flags_core.help_wrap(
'Whether the model runs with custom training loop.'))
flags.DEFINE_bool(
name='is_tpu_pod',
default=False,
help=flags_core.help_wrap('Whether the model runs on a TPU pod.'))
flags.DEFINE_bool(
name='use_tpu_2vm_config',
default=False,
help=flags_core.help_wrap(
'Whether the model runs in 2VM mode, Headless server and unit test '
'all use 1VM config.'))
flags_core.set_defaults(data_dir='/tmp/translate_ende',
model_dir='/tmp/transformer_model',
@@ -216,8 +231,6 @@ def define_transformer_flags():
return True
# pylint: enable=unused-variable
- flags_core.require_cloud_storage(['data_dir', 'model_dir', 'export_dir'])
def get_callbacks():
"""Returns common callbacks."""
...
@@ -23,6 +23,51 @@ import tensorflow as tf
K = tf.keras.backend
class LearningRateSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Learning rate schedule."""
def __init__(self, initial_learning_rate, hidden_size, warmup_steps):
"""Initialize configuration of the learning rate schedule.
Args:
initial_learning_rate: A float, the initial learning rate.
hidden_size: An integer, the model dimension in the hidden layers.
warmup_steps: An integer, the number of steps required for linear warmup.
"""
super(LearningRateSchedule, self).__init__()
self.initial_learning_rate = initial_learning_rate
self.hidden_size = hidden_size
self.warmup_steps = tf.cast(warmup_steps, tf.float32)
def __call__(self, global_step):
"""Calculate learning rate with linear warmup and rsqrt decay.
Args:
global_step: An integer, the current global step used for learning rate
calculation.
Returns:
A float, the learning rate to be used for the current global step.
"""
with tf.name_scope('learning_rate_schedule'):
global_step = tf.cast(global_step, tf.float32)
learning_rate = self.initial_learning_rate
learning_rate *= (self.hidden_size**-0.5)
# Apply linear warmup
learning_rate *= tf.minimum(1.0, global_step / self.warmup_steps)
# Apply rsqrt decay
learning_rate /= tf.sqrt(tf.maximum(global_step, self.warmup_steps))
return learning_rate
def get_config(self):
"""Get the configuration of the learning rate schedule."""
return {
'initial_learning_rate': self.initial_learning_rate,
'hidden_size': self.hidden_size,
'warmup_steps': self.warmup_steps,
}
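A hedged usage sketch for the schedule defined above (the numbers are arbitrary and only assume the class is in scope, e.g. importable from official.transformer.v2.optimizer): a LearningRateSchedule instance can be passed to a Keras optimizer in place of a float and is called with the current step.

import tensorflow as tf

schedule = LearningRateSchedule(
    initial_learning_rate=2.0, hidden_size=512, warmup_steps=16000)
opt = tf.keras.optimizers.Adam(schedule, beta_1=0.9, beta_2=0.997, epsilon=1e-9)
print(float(schedule(1000)), float(schedule(16000)), float(schedule(64000)))
# rises linearly during warmup, peaks at warmup_steps, then decays as 1/sqrt(step)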
class LearningRateFn(object):
"""Creates learning rate function."""
...
@@ -27,12 +27,16 @@ import tempfile
from absl import app as absl_app  # pylint: disable=unused-import
from absl import flags
from absl import logging
import tensorflow as tf
from tensorflow.python.util import object_identity
# pylint: disable=g-bad-import-order
from official.transformer import compute_bleu
from official.transformer.utils import tokenizer
from official.transformer.v2 import data_pipeline
from official.transformer.v2 import metrics
from official.transformer.v2 import misc
from official.transformer.v2 import optimizer
from official.transformer.v2 import transformer
@@ -75,8 +79,8 @@ def evaluate_and_log_bleu(model, bleu_source, bleu_ref, vocab_file):
uncased_score, cased_score = translate_and_compute_bleu(
model, subtokenizer, bleu_source, bleu_ref)
- tf.compat.v1.logging.info("Bleu score (uncased): %s", uncased_score)
- tf.compat.v1.logging.info("Bleu score (cased): %s", cased_score)
logging.info("Bleu score (uncased): %s", uncased_score)
logging.info("Bleu score (cased): %s", cased_score)
return uncased_score, cased_score
@@ -88,26 +92,20 @@ class TransformerTask(object):
Args:
flags_obj: Object containing parsed flag values, i.e., FLAGS.
Raises:
ValueError: if not using static batch for input data on TPU.
"""
self.flags_obj = flags_obj
self.predict_model = None
# Add flag-defined parameters to params object
num_gpus = flags_core.get_num_gpus(flags_obj)
- self.distribution_strategy = distribution_utils.get_distribution_strategy(
- distribution_strategy=flags_obj.distribution_strategy,
- num_gpus=flags_core.get_num_gpus(flags_obj))
- print("Running transformer with num_gpus =", num_gpus)
- if self.distribution_strategy:
- print("For training, using distribution strategy: ",
- self.distribution_strategy)
- else:
- print("Not using any distribution strategy.")
self.params = params = misc.get_model_params(flags_obj.param_set, num_gpus)
params["num_gpus"] = num_gpus
params["use_ctl"] = flags_obj.use_ctl
params["is_tpu_pod"] = flags_obj.is_tpu_pod
params["data_dir"] = flags_obj.data_dir
params["model_dir"] = flags_obj.model_dir
params["static_batch"] = flags_obj.static_batch
@@ -130,33 +128,113 @@ class TransformerTask(object):
"infer_float32_vars")
tf.keras.mixed_precision.experimental.set_policy(policy)
self.distribution_strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=num_gpus,
tpu_address=flags_obj.tpu or "")
if self.use_tpu:
if not params["static_batch"]:
raise ValueError("TPU requires static batch for input data.")
else:
print("Running transformer with num_gpus =", num_gpus)
if self.distribution_strategy:
print("For training, using distribution strategy: ",
self.distribution_strategy)
else:
print("Not using any distribution strategy.")
@property
def use_tpu(self):
if self.distribution_strategy:
return isinstance(self.distribution_strategy,
tf.distribute.experimental.TPUStrategy)
return False
def train(self):
"""Trains the model."""
- params, flags_obj, is_train = self.params, self.flags_obj, True
params = self.params
flags_obj = self.flags_obj
# Sets config options.
keras_utils.set_session_config(
enable_xla=flags_obj.enable_xla)
_ensure_dir(flags_obj.model_dir)
- if self.distribution_strategy:
- with self.distribution_strategy.scope():
- model = transformer.create_model(params, is_train)
with distribution_utils.get_strategy_scope(self.distribution_strategy):
model = transformer.create_model(params, is_train=True)
opt = self._create_optimizer()
- model.compile(opt)
if params["use_ctl"]:
train_loss_metric = tf.keras.metrics.Mean(
"training_loss", dtype=tf.float32)
else:
- model = transformer.create_model(params, is_train)
- opt = self._create_optimizer()
model.compile(opt)
model.summary()
if self.use_tpu:
# Different from experimental_distribute_dataset,
# experimental_distribute_datasets_from_function requires
# per-replica/local batch size.
params["batch_size"] /= self.distribution_strategy.num_replicas_in_sync
train_ds = (
self.distribution_strategy
.experimental_distribute_datasets_from_function(
lambda ctx: data_pipeline.train_input_fn(params)))
else:
train_ds = data_pipeline.train_input_fn(params)
map_data_fn = data_pipeline.map_data_for_transformer_fn
- train_ds = train_ds.map(map_data_fn,
- num_parallel_calls=params["num_parallel_calls"])
train_ds = train_ds.map(
map_data_fn, num_parallel_calls=params["num_parallel_calls"])
if params["use_ctl"]:
train_ds_iterator = iter(train_ds)
callbacks = self._create_callbacks(flags_obj.model_dir, 0, params)
# TODO(b/139418525): Refactor the custom training loop logic.
@tf.function
def train_steps(iterator, steps):
"""Training steps function for TPU runs.
Args:
iterator: The input iterator of the training dataset.
steps: An integer, the number of training steps.
Returns:
A float, the loss value.
"""
def _step_fn(inputs):
"""Per-replica step function."""
inputs, targets = inputs
with tf.GradientTape() as tape:
logits = model([inputs, targets], training=True)
loss = metrics.transformer_loss(logits, targets,
params["label_smoothing"],
params["vocab_size"])
# Scales the loss, which results in using the average loss across all
# of the replicas for backprop.
scaled_loss = loss / self.distribution_strategy.num_replicas_in_sync
# De-dupes variables due to keras tracking issues.
tvars = list(
object_identity.ObjectIdentitySet(model.trainable_variables))
grads = tape.gradient(scaled_loss, tvars)
opt.apply_gradients(zip(grads, tvars))
# For reporting, the metric takes the mean of losses.
train_loss_metric.update_state(loss)
for _ in tf.range(steps):
train_loss_metric.reset_states()
self.distribution_strategy.experimental_run_v2(
_step_fn, args=(next(iterator),))
if self.use_tpu:
checkpoint = tf.train.Checkpoint(model=model, optimizer=opt)
latest_checkpoint = tf.train.latest_checkpoint(flags_obj.model_dir)
if latest_checkpoint:
checkpoint.restore(latest_checkpoint)
logging.info("Loaded checkpoint %s", latest_checkpoint)
if flags_obj.train_steps < flags_obj.steps_between_evals:
flags_obj.steps_between_evals = flags_obj.train_steps
iterations = flags_obj.train_steps // flags_obj.steps_between_evals
@@ -165,28 +243,54 @@ class TransformerTask(object):
cased_score_history, uncased_score_history = [], []
for i in range(1, iterations + 1):
print("Start train iteration:{}/{}".format(i, iterations))
history = None
if params["use_ctl"]:
if not self.use_tpu:
raise NotImplementedError(
"Custom training loop on GPUs is not implemented.")
train_steps_per_eval = tf.convert_to_tensor(
flags_obj.steps_between_evals, dtype=tf.int32)
# Runs training steps.
train_steps(train_ds_iterator, train_steps_per_eval)
train_loss = train_loss_metric.result().numpy().astype(float)
logging.info("Train Step: %d/%d / loss = %s",
i * flags_obj.steps_between_evals, flags_obj.train_steps,
train_loss)
checkpoint_name = checkpoint.save(
os.path.join(
flags_obj.model_dir,
"ctl_step_{}.ckpt".format(i * flags_obj.steps_between_evals)))
logging.info("Saved checkpoint to %s", checkpoint_name)
else:
if self.use_tpu:
raise NotImplementedError(
"Keras model.fit on TPUs is not implemented.")
history = model.fit(
train_ds,
- initial_epoch=i-1,
initial_epoch=i - 1,
epochs=i,
steps_per_epoch=flags_obj.steps_between_evals,
callbacks=callbacks,
- # If TimeHistory is enabled, progress bar would be messy. Increase the
- # verbose level to get rid of it.
# If TimeHistory is enabled, progress bar would be messy. Increase
# the verbose level to get rid of it.
verbose=(2 if flags_obj.enable_time_history else 1))
logging.info("Train history: {}".format(history.history))
print("End train iteration:{}/{} global step:{}".format(
i,
iterations,
i*flags_obj.steps_between_evals))
- tf.compat.v1.logging.info("Train history: {}".format(history.history))
- stats = misc.build_stats(history, callbacks)
if (flags_obj.bleu_source and flags_obj.bleu_ref):
uncased_score, cased_score = self.eval()
cased_score_history.append([i, cased_score])
uncased_score_history.append([i, uncased_score])
- stats = misc.build_stats(history, callbacks)
stats = ({
"loss": train_loss
} if history is None else misc.build_stats(history, callbacks))
if uncased_score and cased_score:
stats["bleu_uncased"] = uncased_score
stats["bleu_cased"] = cased_score
@@ -209,10 +313,11 @@ class TransformerTask(object):
def predict(self):
"""Predicts result from the model."""
- params, flags_obj, is_train = self.params, self.flags_obj, False
params = self.params
flags_obj = self.flags_obj
with tf.name_scope("model"):
- model = transformer.create_model(params, is_train)
model = transformer.create_model(params, is_train=False)
self._load_weights_if_possible(
model, tf.train.latest_checkpoint(self.flags_obj.model_dir))
model.summary()
@@ -242,7 +347,14 @@ class TransformerTask(object):
def _load_weights_if_possible(self, model, init_weight_path=None):
"""Loads model weights when it is provided."""
if init_weight_path:
- tf.compat.v1.logging.info("Load weights: {}".format(init_weight_path))
logging.info("Load weights: {}".format(init_weight_path))
# TODO(b/139414977): Having the same variable restoring method for both
# TPU and GPU.
if self.use_tpu:
checkpoint = tf.train.Checkpoint(
model=model, optimizer=self._create_optimizer())
checkpoint.restore(init_weight_path)
else:
model.load_weights(init_weight_path)
else:
print("Weights not loaded from path:{}".format(init_weight_path))
@@ -250,8 +362,13 @@ class TransformerTask(object):
def _create_optimizer(self):
"""Creates optimizer."""
params = self.params
# TODO(b/139414679): Explore the difference between using
# LearningRateSchedule and callback for GPU runs, and try to merge them.
lr_schedule = optimizer.LearningRateSchedule(
params["learning_rate"], params["hidden_size"],
params["learning_rate_warmup_steps"])
opt = tf.keras.optimizers.Adam(
- params["learning_rate"],
lr_schedule if self.use_tpu else params["learning_rate"],
params["optimizer_adam_beta1"],
params["optimizer_adam_beta2"],
epsilon=params["optimizer_adam_epsilon"])
@@ -264,14 +381,16 @@ class TransformerTask(object):
def _ensure_dir(log_dir):
"""Makes log dir if not existed."""
- if not os.path.exists(log_dir):
- os.makedirs(log_dir)
if not tf.io.gfile.exists(log_dir):
tf.io.gfile.makedirs(log_dir)
def main(_):
flags_obj = flags.FLAGS
with logger.benchmark_context(flags_obj):
task = TransformerTask(flags_obj)
def _run_task(task):
if flags_obj.mode == "train": if flags_obj.mode == "train":
task.train() task.train()
elif flags_obj.mode == "predict": elif flags_obj.mode == "predict":
...@@ -281,8 +400,15 @@ def main(_): ...@@ -281,8 +400,15 @@ def main(_):
else: else:
raise ValueError("Invalid mode {}".format(flags_obj.mode)) raise ValueError("Invalid mode {}".format(flags_obj.mode))
if flags_obj.distribution_strategy != "tpu":
_run_task(task)
else:
primary_cpu_task = "/job:worker" if flags_obj.use_tpu_2vm_config else ""
with tf.device(primary_cpu_task):
_run_task(task)
if __name__ == "__main__": if __name__ == "__main__":
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) logging.set_verbosity(logging.INFO)
misc.define_transformer_flags() misc.define_transformer_flags()
absl_app.run(main) absl_app.run(main)
@@ -80,11 +80,19 @@ class TransformerTaskTest(tf.test.TestCase):
self.assertTrue(os.path.exists(filepath))
def test_train_no_dist_strat(self):
if context.num_gpus() >= 2:
self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
t = tm.TransformerTask(FLAGS)
t.train()
def test_train_static_batch(self):
if context.num_gpus() >= 2:
self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
FLAGS.distribution_strategy = 'one_device'
if tf.test.is_built_with_cuda():
FLAGS.num_gpus = 1
else:
FLAGS.num_gpus = 0
FLAGS.static_batch = True
t = tm.TransformerTask(FLAGS)
t.train()
@@ -97,6 +105,7 @@ class TransformerTaskTest(tf.test.TestCase):
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_train_fp16(self):
FLAGS.distribution_strategy = 'one_device'
FLAGS.dtype = 'fp16'
t = tm.TransformerTask(FLAGS)
t.train()
@@ -105,8 +114,8 @@ class TransformerTaskTest(tf.test.TestCase):
def test_train_2_gpu(self):
if context.num_gpus() < 2:
self.skipTest(
- '{} GPUs are not available for this test. {} GPUs are available'.
- format(2, context.num_gpus()))
'{} GPUs are not available for this test. {} GPUs are available'
.format(2, context.num_gpus()))
FLAGS.distribution_strategy = 'mirrored'
FLAGS.num_gpus = 2
FLAGS.param_set = 'base'
@@ -117,8 +126,8 @@ class TransformerTaskTest(tf.test.TestCase):
def test_train_2_gpu_fp16(self):
if context.num_gpus() < 2:
self.skipTest(
- '{} GPUs are not available for this test. {} GPUs are available'.
- format(2, context.num_gpus()))
'{} GPUs are not available for this test. {} GPUs are available'
.format(2, context.num_gpus()))
FLAGS.distribution_strategy = 'mirrored'
FLAGS.num_gpus = 2
FLAGS.param_set = 'base'
@@ -153,16 +162,22 @@ class TransformerTaskTest(tf.test.TestCase):
FLAGS(update_flags)
def test_predict(self):
if context.num_gpus() >= 2:
self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
self._prepare_files_and_flags()
t = tm.TransformerTask(FLAGS)
t.predict()
def test_predict_fp16(self):
if context.num_gpus() >= 2:
self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
self._prepare_files_and_flags('--dtype=fp16')
t = tm.TransformerTask(FLAGS)
t.predict()
def test_eval(self):
if context.num_gpus() >= 2:
self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
self._prepare_files_and_flags()
t = tm.TransformerTask(FLAGS)
t.eval()
...
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flags related to distributed execution."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
import tensorflow as tf
from official.utils.flags._conventions import help_wrap
def define_distribution(worker_hosts=True, task_index=True):
"""Register distributed execution flags.
Args:
worker_hosts: Create a flag for specifying comma-separated list of workers.
task_index: Create a flag for specifying index of task.
Returns:
A list of flags for core.py to marks as key flags.
"""
key_flags = []
if worker_hosts:
flags.DEFINE_string(
name='worker_hosts', default=None,
help=help_wrap(
'Comma-separated list of worker ip:port pairs for running '
'multi-worker models with DistributionStrategy. The user would '
'start the program on each host with identical value for this '
'flag.'))
if task_index:
flags.DEFINE_integer(
name='task_index', default=-1,
help=help_wrap('If multi-worker training, the task_index of this '
'worker.'))
return key_flags
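A hedged sketch of consuming the new flags (the script below and its values are invented): once define_distribution() has registered them, --worker_hosts and --task_index parse like any other absl flag.

from absl import app, flags
from official.utils.flags import _distribution

_distribution.define_distribution()
FLAGS = flags.FLAGS

def main(_):
  if FLAGS.worker_hosts:
    hosts = FLAGS.worker_hosts.split(',')
    print('worker %d of %d: %s' % (FLAGS.task_index, len(hosts), hosts))

if __name__ == '__main__':
  # e.g. --worker_hosts=10.0.0.1:2222,10.0.0.2:2222 --task_index=0
  app.run(main)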
@@ -54,7 +54,7 @@ def get_loss_scale(flags_obj, default_for_fp16):
def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
- synthetic_data=True, max_train_steps=True, dtype=True,
synthetic_data=True, max_train_steps=False, dtype=True,
all_reduce_alg=True, num_packs=True,
tf_gpu_thread_mode=False,
datasets_num_private_threads=False,
...
...@@ -32,6 +32,7 @@ from official.utils.flags import _base ...@@ -32,6 +32,7 @@ from official.utils.flags import _base
from official.utils.flags import _benchmark from official.utils.flags import _benchmark
from official.utils.flags import _conventions from official.utils.flags import _conventions
from official.utils.flags import _device from official.utils.flags import _device
from official.utils.flags import _distribution
from official.utils.flags import _misc from official.utils.flags import _misc
from official.utils.flags import _performance from official.utils.flags import _performance
...@@ -77,6 +78,8 @@ define_benchmark = register_key_flags_in_core(_benchmark.define_benchmark) ...@@ -77,6 +78,8 @@ define_benchmark = register_key_flags_in_core(_benchmark.define_benchmark)
define_device = register_key_flags_in_core(_device.define_device) define_device = register_key_flags_in_core(_device.define_device)
define_image = register_key_flags_in_core(_misc.define_image) define_image = register_key_flags_in_core(_misc.define_image)
define_performance = register_key_flags_in_core(_performance.define_performance) define_performance = register_key_flags_in_core(_performance.define_performance)
define_distribution = register_key_flags_in_core(
_distribution.define_distribution)
help_wrap = _conventions.help_wrap help_wrap = _conventions.help_wrap
......
...@@ -24,6 +24,8 @@ import random ...@@ -24,6 +24,8 @@ import random
import string import string
import tensorflow as tf import tensorflow as tf
from official.utils.misc import tpu_lib
def _collective_communication(all_reduce_alg): def _collective_communication(all_reduce_alg):
"""Return a CollectiveCommunication based on all_reduce_alg. """Return a CollectiveCommunication based on all_reduce_alg.
...@@ -83,16 +85,18 @@ def get_distribution_strategy(distribution_strategy="default", ...@@ -83,16 +85,18 @@ def get_distribution_strategy(distribution_strategy="default",
num_gpus=0, num_gpus=0,
num_workers=1, num_workers=1,
all_reduce_alg=None, all_reduce_alg=None,
num_packs=1): num_packs=1,
tpu_address=None):
"""Return a DistributionStrategy for running the model. """Return a DistributionStrategy for running the model.
Args: Args:
distribution_strategy: a string specifying which distribution strategy to distribution_strategy: a string specifying which distribution strategy to
use. Accepted values are 'off', 'default', 'one_device', 'mirrored', use. Accepted values are 'off', 'default', 'one_device', 'mirrored',
'parameter_server', 'multi_worker_mirrored', case insensitive. 'off' means 'parameter_server', 'multi_worker_mirrored', and 'tpu' -- case insensitive.
not to use Distribution Strategy; 'default' means to choose from 'off' means not to use Distribution Strategy; 'default' means to choose from
`MirroredStrategy`, `MultiWorkerMirroredStrategy`, or `OneDeviceStrategy` `MirroredStrategy`, `MultiWorkerMirroredStrategy`, or `OneDeviceStrategy`
according to the number of GPUs and number of workers. according to the number of GPUs and number of workers. 'tpu' means to use
TPUStrategy using `tpu_address`.
num_gpus: Number of GPUs to run this model. num_gpus: Number of GPUs to run this model.
num_workers: Number of workers to run this model. num_workers: Number of workers to run this model.
all_reduce_alg: Optional. Specifies which algorithm to use when performing all_reduce_alg: Optional. Specifies which algorithm to use when performing
...@@ -102,12 +106,14 @@ def get_distribution_strategy(distribution_strategy="default", ...@@ -102,12 +106,14 @@ def get_distribution_strategy(distribution_strategy="default",
device topology. device topology.
num_packs: Optional. Sets the `num_packs` in `tf.distribute.NcclAllReduce` num_packs: Optional. Sets the `num_packs` in `tf.distribute.NcclAllReduce`
or `tf.distribute.HierarchicalCopyAllReduce` for `MirroredStrategy`. or `tf.distribute.HierarchicalCopyAllReduce` for `MirroredStrategy`.
tpu_address: Optional. String that identifies the TPU to connect to. Must not
be None if `distribution_strategy` is set to `tpu`.
Returns: Returns:
tf.distribute.DistributionStrategy object. tf.distribute.DistributionStrategy object.
Raises: Raises:
ValueError: if `distribution_strategy` is 'off' or 'one_device' and ValueError: if `distribution_strategy` is 'off' or 'one_device' and
`num_gpus` is larger than 1; or `num_gpus` is negative. `num_gpus` is larger than 1; or `num_gpus` is negative or if
`distribution_strategy` is `tpu` but `tpu_address` is not specified.
""" """
if num_gpus < 0: if num_gpus < 0:
raise ValueError("`num_gpus` can not be negative.") raise ValueError("`num_gpus` can not be negative.")
...@@ -120,6 +126,12 @@ def get_distribution_strategy(distribution_strategy="default", ...@@ -120,6 +126,12 @@ def get_distribution_strategy(distribution_strategy="default",
"flag cannot be set to 'off'.".format(num_gpus, num_workers)) "flag cannot be set to 'off'.".format(num_gpus, num_workers))
return None return None
if distribution_strategy == "tpu":
# When tpu_address is an empty string, we communicate with local TPUs.
# Initialize TPU System.
cluster_resolver = tpu_lib.tpu_initialize(tpu_address)
return tf.distribute.experimental.TPUStrategy(cluster_resolver)
if distribution_strategy == "multi_worker_mirrored": if distribution_strategy == "multi_worker_mirrored":
return tf.distribute.experimental.MultiWorkerMirroredStrategy( return tf.distribute.experimental.MultiWorkerMirroredStrategy(
communication=_collective_communication(all_reduce_alg)) communication=_collective_communication(all_reduce_alg))
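For illustration, a hedged sketch of calling the extended helper; it assumes two local GPUs for the mirrored case, the placeholder Keras model is not part of this change, and the commented-out TPU call follows the docstring's tpu_address convention.

import tensorflow as tf

from official.utils.misc import distribution_utils

# Multi-GPU training: MirroredStrategy with NCCL all-reduce over 2 GPUs.
strategy = distribution_utils.get_distribution_strategy(
    distribution_strategy='mirrored', num_gpus=2,
    all_reduce_alg='nccl', num_packs=1)

# Cloud TPU training: pass the resolver address, or '' for a local TPU worker.
# strategy = distribution_utils.get_distribution_strategy(
#     distribution_strategy='tpu', tpu_address='grpc://10.0.0.2:8470')

with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])  # placeholder model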
...@@ -190,38 +202,64 @@ class SyntheticDataset(object): ...@@ -190,38 +202,64 @@ class SyntheticDataset(object):
"""A dataset that generates synthetic data on each device.""" """A dataset that generates synthetic data on each device."""
def __init__(self, dataset, split_by=1): def __init__(self, dataset, split_by=1):
self._input_data = {}
# dataset.take(1) doesn't have GPU kernel. # dataset.take(1) doesn't have GPU kernel.
with tf.device('device:CPU:0'): with tf.device('device:CPU:0'):
tensor = tf.data.experimental.get_single_element(dataset.take(1)) tensor = tf.data.experimental.get_single_element(dataset.take(1))
flat_tensor = tf.nest.flatten(tensor) flat_tensor = tf.nest.flatten(tensor)
variable_data = [] variable_data = []
self._initializers = [] initializers = []
for t in flat_tensor: for t in flat_tensor:
rebatched_t = tf.split(t, num_or_size_splits=split_by, axis=0)[0] rebatched_t = tf.split(t, num_or_size_splits=split_by, axis=0)[0]
assert rebatched_t.shape.is_fully_defined(), rebatched_t.shape assert rebatched_t.shape.is_fully_defined(), rebatched_t.shape
v = tf.compat.v1.get_local_variable(self.random_name(), v = tf.compat.v1.get_local_variable(self._random_name(),
initializer=rebatched_t) initializer=rebatched_t)
variable_data.append(v) variable_data.append(v)
self._initializers.append(v.initializer) initializers.append(v.initializer)
self._input_data = tf.nest.pack_sequence_as(tensor, variable_data) input_data = tf.nest.pack_sequence_as(tensor, variable_data)
self._iterator = SyntheticIterator(input_data, initializers)
def _random_name(self, size=10, chars=string.ascii_uppercase + string.digits):
return ''.join(random.choice(chars) for _ in range(size))
def __iter__(self):
return self._iterator
def make_one_shot_iterator(self):
return self._iterator
def make_initializable_iterator(self):
return self._iterator
class SyntheticIterator(object):
"""A dataset that generates synthetic data on each device."""
def __init__(self, input_data, initializers):
self._input_data = input_data
self._initializers = initializers
def get_next(self): def get_next(self):
return self._input_data return self._input_data
def next(self):
return self.__next__()
def __next__(self):
try:
return self.get_next()
except tf.errors.OutOfRangeError:
raise StopIteration
def initialize(self): def initialize(self):
if tf.executing_eagerly(): if tf.executing_eagerly():
return tf.no_op() return tf.no_op()
else: else:
return self._initializers return self._initializers
def random_name(self, size=10, chars=string.ascii_uppercase + string.digits):
return ''.join(random.choice(chars) for _ in range(size))
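For illustration, a small sketch of the iterator contract introduced above: SyntheticIterator simply replays whatever structure it was constructed with. The toy batch is an assumption.

import tensorflow as tf

from official.utils.misc.distribution_utils import SyntheticIterator

# A pre-materialized batch standing in for the cached synthetic tensors.
batch = {'features': tf.zeros([8, 4]), 'labels': tf.zeros([8], tf.int32)}
it = SyntheticIterator(input_data=batch, initializers=[])

assert it.get_next() is batch  # always the same cached structure
assert next(it) is batch       # __next__ delegates to get_next()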
def _monkey_patch_dataset_method(strategy): def _monkey_patch_dataset_method(strategy):
"""Monkey-patch `strategy`'s `make_dataset_iterator` method.""" """Monkey-patch `strategy`'s `make_dataset_iterator` method."""
def make_dataset_iterator(self, dataset): def make_dataset(self, dataset):
tf.compat.v1.logging.info('Using pure synthetic data.') tf.compat.v1.logging.info('Using pure synthetic data.')
with self.scope(): with self.scope():
if self.extended._global_batch_size: # pylint: disable=protected-access if self.extended._global_batch_size: # pylint: disable=protected-access
...@@ -229,22 +267,34 @@ def _monkey_patch_dataset_method(strategy): ...@@ -229,22 +267,34 @@ def _monkey_patch_dataset_method(strategy):
else: else:
return SyntheticDataset(dataset) return SyntheticDataset(dataset)
strategy.org_make_dataset_iterator = strategy.make_dataset_iterator def make_iterator(self, dataset):
strategy.make_dataset_iterator = make_dataset_iterator dist_dataset = make_dataset(self, dataset)
return iter(dist_dataset)
strategy.orig_make_dataset_iterator = strategy.make_dataset_iterator
strategy.make_dataset_iterator = make_iterator
strategy.orig_distribute_dataset = strategy.experimental_distribute_dataset
strategy.experimental_distribute_dataset = make_dataset
def _undo_monkey_patch_dataset_method(strategy): def _undo_monkey_patch_dataset_method(strategy):
if hasattr(strategy, 'org_make_dataset_iterator'): if hasattr(strategy, 'orig_make_dataset_iterator'):
strategy.make_dataset_iterator = strategy.org_make_dataset_iterator strategy.make_dataset_iterator = strategy.orig_make_dataset_iterator
if hasattr(strategy, 'orig_distribute_dataset'):
strategy.experimental_distribute_dataset = strategy.orig_distribute_dataset
def set_up_synthetic_data(): def set_up_synthetic_data():
_monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy) _monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy)
_monkey_patch_dataset_method(tf.distribute.MirroredStrategy) _monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
_monkey_patch_dataset_method(
tf.distribute.experimental.MultiWorkerMirroredStrategy)
# TODO(tobyboyd): Remove when contrib.distribute is all in core. # TODO(tobyboyd): Remove when contrib.distribute is all in core.
if hasattr(tf, 'contrib'): if hasattr(tf, 'contrib'):
_monkey_patch_dataset_method(tf.contrib.distribute.MirroredStrategy) _monkey_patch_dataset_method(tf.contrib.distribute.MirroredStrategy)
_monkey_patch_dataset_method(tf.contrib.distribute.OneDeviceStrategy) _monkey_patch_dataset_method(tf.contrib.distribute.OneDeviceStrategy)
_monkey_patch_dataset_method(
tf.contrib.distribute.CollectiveAllReduceStrategy)
else: else:
print('Contrib missing: Skip monkey patch tf.contrib.distribute.*') print('Contrib missing: Skip monkey patch tf.contrib.distribute.*')
...@@ -252,10 +302,14 @@ def set_up_synthetic_data(): ...@@ -252,10 +302,14 @@ def set_up_synthetic_data():
def undo_set_up_synthetic_data(): def undo_set_up_synthetic_data():
_undo_monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy) _undo_monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy)
_undo_monkey_patch_dataset_method(tf.distribute.MirroredStrategy) _undo_monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
_undo_monkey_patch_dataset_method(
tf.distribute.experimental.MultiWorkerMirroredStrategy)
# TODO(tobyboyd): Remove when contrib.distribute is all in core. # TODO(tobyboyd): Remove when contrib.distribute is all in core.
if hasattr(tf, 'contrib'): if hasattr(tf, 'contrib'):
_undo_monkey_patch_dataset_method(tf.contrib.distribute.MirroredStrategy) _undo_monkey_patch_dataset_method(tf.contrib.distribute.MirroredStrategy)
_undo_monkey_patch_dataset_method(tf.contrib.distribute.OneDeviceStrategy) _undo_monkey_patch_dataset_method(tf.contrib.distribute.OneDeviceStrategy)
_undo_monkey_patch_dataset_method(
tf.contrib.distribute.CollectiveAllReduceStrategy)
else: else:
print('Contrib missing: Skip remove monkey patch tf.contrib.distribute.*') print('Contrib missing: Skip remove monkey patch tf.contrib.distribute.*')
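For illustration, a hedged sketch of the intended pairing: patch before the input pipeline is built and restore the originals afterwards. The training callable is a stand-in.

from official.utils.misc import distribution_utils


def run_training():
  """Stand-in for the caller's real training loop."""
  pass


# Redirect the strategies' dataset entry points to SyntheticDataset so one
# cached batch is replayed instead of running the real input pipeline.
distribution_utils.set_up_synthetic_data()
try:
  run_training()
finally:
  # Restore make_dataset_iterator / experimental_distribute_dataset.
  distribution_utils.undo_set_up_synthetic_data()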
......
...@@ -31,3 +31,8 @@ def tpu_initialize(tpu_address): ...@@ -31,3 +31,8 @@ def tpu_initialize(tpu_address):
tf.config.experimental_connect_to_host(cluster_resolver.master()) tf.config.experimental_connect_to_host(cluster_resolver.master())
tf.tpu.experimental.initialize_tpu_system(cluster_resolver) tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
return cluster_resolver return cluster_resolver
def get_primary_cpu_task(use_remote_tpu=False):
"""Returns remote TPU worker address. No-op for GPU/CPU training."""
return "/job:worker" if use_remote_tpu else ""
...@@ -29,7 +29,7 @@ from absl import flags ...@@ -29,7 +29,7 @@ from absl import flags
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1): def run_synthetic(main, tmp_root, extra_flags=None, synth=True):
"""Performs a minimal run of a model. """Performs a minimal run of a model.
This function is intended to test for syntax errors throughout a model. A This function is intended to test for syntax errors throughout a model. A
...@@ -41,7 +41,6 @@ def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1): ...@@ -41,7 +41,6 @@ def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1):
tmp_root: Root path for the temp directory created by the test class. tmp_root: Root path for the temp directory created by the test class.
extra_flags: Additional flags passed by the caller of this function. extra_flags: Additional flags passed by the caller of this function.
synth: Use synthetic data. synth: Use synthetic data.
max_train: Maximum number of allowed training steps.
""" """
extra_flags = [] if extra_flags is None else extra_flags extra_flags = [] if extra_flags is None else extra_flags
...@@ -54,9 +53,6 @@ def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1): ...@@ -54,9 +53,6 @@ def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1):
if synth: if synth:
args.append("--use_synthetic_data") args.append("--use_synthetic_data")
if max_train is not None:
args.extend(["--max_train_steps", str(max_train)])
try: try:
flags_core.parse_flags(argv=args) flags_core.parse_flags(argv=args)
main(flags.FLAGS) main(flags.FLAGS)
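For illustration, a hedged sketch of the calling pattern after this change. The my_model module and its flags are hypothetical; a real smoke test imports the model's own main and flag definitions, and caps run length through the model's own flags now that max_train_steps is gone.

import tensorflow as tf

from official.utils.testing import integration

import my_model  # hypothetical module exposing define_flags() and main(flags_obj)

my_model.define_flags()


class MyModelSmokeTest(tf.test.TestCase):

  def test_runs_with_synthetic_data(self):
    integration.run_synthetic(
        main=my_model.main, tmp_root=self.get_temp_dir(),
        extra_flags=['-train_epochs', '1'], synth=True)


if __name__ == '__main__':
  tf.test.main()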
......
This folder contains the Keras implementation of the ResNet models. For more This folder contains the Keras implementation of the ResNet models. For more
information about the models, please refer to this [README file](../README.md). information about the models, please refer to this [README file](../../README.md).
Similar to the [estimator implementation](/official/resnet), the Keras Similar to the [estimator implementation](../../r1/resnet), the Keras
implementation has code for both CIFAR-10 data and ImageNet data. The CIFAR-10 implementation has code for both CIFAR-10 data and ImageNet data. The CIFAR-10
version uses a ResNet56 model implemented in version uses a ResNet56 model implemented in
[`resnet_cifar_model.py`](./resnet_cifar_model.py), and the ImageNet version [`resnet_cifar_model.py`](./resnet_cifar_model.py), and the ImageNet version
......