Commit ca574c36 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Dataset fixes.

PiperOrigin-RevId: 264174941
parent 3fd8bcfb
...@@ -171,21 +171,22 @@ class TransformerTask(object): ...@@ -171,21 +171,22 @@ class TransformerTask(object):
model.summary() model.summary()
train_ds = data_pipeline.train_input_fn(params)
if self.use_tpu: if self.use_tpu:
if params["is_tpu_pod"]: # Different from experimental_distribute_dataset,
train_ds = ( # experimental_distribute_datasets_from_function requires
self.distribution_strategy # per-replica/local batch size.
.experimental_distribute_datasets_from_function( params["batch_size"] /= self.distribution_strategy.num_replicas_in_sync
lambda: data_pipeline.train_input_fn(params))) train_ds = (
else: self.distribution_strategy
train_ds = ( .experimental_distribute_datasets_from_function(
self.distribution_strategy.experimental_distribute_dataset(train_ds) lambda ctx: data_pipeline.train_input_fn(params)))
)
else: else:
train_ds = data_pipeline.train_input_fn(params)
map_data_fn = data_pipeline.map_data_for_transformer_fn map_data_fn = data_pipeline.map_data_for_transformer_fn
train_ds = train_ds.map( train_ds = train_ds.map(
map_data_fn, num_parallel_calls=params["num_parallel_calls"]) map_data_fn, num_parallel_calls=params["num_parallel_calls"])
if params["use_ctl"]:
train_ds_iterator = iter(train_ds)
callbacks = self._create_callbacks(flags_obj.model_dir, 0, params) callbacks = self._create_callbacks(flags_obj.model_dir, 0, params)
...@@ -251,7 +252,7 @@ class TransformerTask(object): ...@@ -251,7 +252,7 @@ class TransformerTask(object):
flags_obj.steps_between_evals, dtype=tf.int32) flags_obj.steps_between_evals, dtype=tf.int32)
# Runs training steps. # Runs training steps.
train_steps(iter(train_ds), train_steps_per_eval) train_steps(train_ds_iterator, train_steps_per_eval)
train_loss = train_loss_metric.result().numpy().astype(float) train_loss = train_loss_metric.result().numpy().astype(float)
logging.info("Train Step: %d/%d / loss = %s", logging.info("Train Step: %d/%d / loss = %s",
i * flags_obj.steps_between_evals, flags_obj.train_steps, i * flags_obj.steps_between_evals, flags_obj.train_steps,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment