Commit ca574c36 authored by A. Unique TensorFlower

Dataset fixes.

PiperOrigin-RevId: 264174941
parent 3fd8bcfb
@@ -171,21 +171,22 @@ class TransformerTask(object):
     model.summary()
-    train_ds = data_pipeline.train_input_fn(params)
     if self.use_tpu:
-      if params["is_tpu_pod"]:
-        train_ds = (
-            self.distribution_strategy
-            .experimental_distribute_datasets_from_function(
-                lambda: data_pipeline.train_input_fn(params)))
-      else:
-        train_ds = (
-            self.distribution_strategy.experimental_distribute_dataset(train_ds)
-        )
+      # Different from experimental_distribute_dataset,
+      # experimental_distribute_datasets_from_function requires
+      # per-replica/local batch size.
+      params["batch_size"] /= self.distribution_strategy.num_replicas_in_sync
+      train_ds = (
+          self.distribution_strategy
+          .experimental_distribute_datasets_from_function(
+              lambda ctx: data_pipeline.train_input_fn(params)))
+    else:
+      train_ds = data_pipeline.train_input_fn(params)
     map_data_fn = data_pipeline.map_data_for_transformer_fn
     train_ds = train_ds.map(
         map_data_fn, num_parallel_calls=params["num_parallel_calls"])
+    if params["use_ctl"]:
+      train_ds_iterator = iter(train_ds)
     callbacks = self._create_callbacks(flags_obj.model_dir, 0, params)
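
Note on the hunk above (not part of the commit): experimental_distribute_datasets_from_function calls the supplied function with a tf.distribute.InputContext, which is presumably why the rewritten lambda accepts a ctx argument, and unlike experimental_distribute_dataset it expects that function to build a pipeline batched with the per-replica batch size rather than the global one, hence the division by num_replicas_in_sync. Below is a minimal sketch of that distinction, assuming a hypothetical MirroredStrategy setup that is unrelated to the Transformer code.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH_SIZE = 64

# experimental_distribute_dataset: hand it a dataset batched with the GLOBAL
# batch size; the strategy splits each batch across replicas itself.
global_ds = tf.data.Dataset.range(1024).batch(GLOBAL_BATCH_SIZE)
dist_ds_a = strategy.experimental_distribute_dataset(global_ds)

# experimental_distribute_datasets_from_function: the function is called with
# an InputContext and must batch with the PER-REPLICA (local) batch size.
def dataset_fn(input_context):
  per_replica_batch = input_context.get_per_replica_batch_size(
      GLOBAL_BATCH_SIZE)
  ds = tf.data.Dataset.range(1024)
  ds = ds.shard(input_context.num_input_pipelines,
                input_context.input_pipeline_id)
  return ds.batch(per_replica_batch)

dist_ds_b = strategy.experimental_distribute_datasets_from_function(dataset_fn)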
@@ -251,7 +252,7 @@ class TransformerTask(object):
             flags_obj.steps_between_evals, dtype=tf.int32)
         # Runs training steps.
-        train_steps(iter(train_ds), train_steps_per_eval)
+        train_steps(train_ds_iterator, train_steps_per_eval)
         train_loss = train_loss_metric.result().numpy().astype(float)
         logging.info("Train Step: %d/%d / loss = %s",
                      i * flags_obj.steps_between_evals, flags_obj.train_steps,
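
Note on the hunk above (not part of the commit): the old call built a fresh iterator with iter(train_ds) at every evaluation interval, which restarts the input pipeline from the beginning each time, whereas train_ds_iterator is created once in the first hunk (under params["use_ctl"]) and is advanced across intervals. A minimal sketch of that pattern follows, with a placeholder dataset and training step that are not from the repository.

import tensorflow as tf

# Placeholder pipeline and step; the real code uses data_pipeline.train_input_fn
# and a distributed train_steps function.
train_ds = tf.data.Dataset.range(1000).batch(10)
train_ds_iterator = iter(train_ds)  # created once, before the training loop

def train_step(batch):
  return tf.reduce_sum(batch)  # stand-in for one optimization step

steps_between_evals = 5
for interval in range(3):
  for _ in range(steps_between_evals):
    # Reusing the same iterator continues from where the previous interval
    # stopped; calling iter(train_ds) here instead would replay the dataset
    # from the start on every interval.
    train_step(next(train_ds_iterator))
  # ... evaluation between intervals would go here ...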