Commit f5014889 authored by Ruoxin Sang, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 312194218
parent 25160730
...@@ -51,7 +51,6 @@ from __future__ import absolute_import ...@@ -51,7 +51,6 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import math
import os import os
from absl import logging from absl import logging
...@@ -157,7 +156,7 @@ def _batch_examples(dataset, batch_size, max_length): ...@@ -157,7 +156,7 @@ def _batch_examples(dataset, batch_size, max_length):
# Create list of batch sizes for each bucket_id, so that # Create list of batch sizes for each bucket_id, so that
# bucket_batch_size[bucket_id] * buckets_max[bucket_id] <= batch_size # bucket_batch_size[bucket_id] * buckets_max[bucket_id] <= batch_size
bucket_batch_sizes = [batch_size // x for x in buckets_max] bucket_batch_sizes = [int(batch_size) // x for x in buckets_max]
# bucket_id will be a tensor, so convert this list to a tensor as well. # bucket_id will be a tensor, so convert this list to a tensor as well.
bucket_batch_sizes = tf.constant(bucket_batch_sizes, dtype=tf.int64) bucket_batch_sizes = tf.constant(bucket_batch_sizes, dtype=tf.int64)
...@@ -270,7 +269,8 @@ def _read_and_batch_from_files( ...@@ -270,7 +269,8 @@ def _read_and_batch_from_files(
def _generate_synthetic_data(params): def _generate_synthetic_data(params):
"""Create synthetic data based on the parameter batch size.""" """Create synthetic data based on the parameter batch size."""
batch = length = int(math.sqrt(params["batch_size"])) batch_size = int(params["batch_size"] // params["max_length"])
length = params["max_length"]
dataset = model_helpers.generate_synthetic_data( dataset = model_helpers.generate_synthetic_data(
input_shape=tf.TensorShape([length]), input_shape=tf.TensorShape([length]),
input_value=1, input_value=1,
...@@ -279,7 +279,11 @@ def _generate_synthetic_data(params): ...@@ -279,7 +279,11 @@ def _generate_synthetic_data(params):
label_value=1, label_value=1,
label_dtype=tf.int64, label_dtype=tf.int64,
) )
return dataset.batch(batch, drop_remainder=True) if params["static_batch"]:
dataset = dataset.batch(batch_size, drop_remainder=True)
else:
dataset = dataset.padded_batch(batch_size, ([None], [None]))
return dataset
def train_input_fn(params, ctx=None): def train_input_fn(params, ctx=None):
......
...@@ -168,8 +168,6 @@ class TransformerTask(object): ...@@ -168,8 +168,6 @@ class TransformerTask(object):
tpu_address=flags_obj.tpu or "") tpu_address=flags_obj.tpu or "")
if self.use_tpu: if self.use_tpu:
params["num_replicas"] = self.distribution_strategy.num_replicas_in_sync params["num_replicas"] = self.distribution_strategy.num_replicas_in_sync
if not params["static_batch"]:
raise ValueError("TPU requires static batch for input data.")
else: else:
logging.info("Running transformer with num_gpus = %d", num_gpus) logging.info("Running transformer with num_gpus = %d", num_gpus)
......
...@@ -61,6 +61,7 @@ class TransformerTaskTest(tf.test.TestCase): ...@@ -61,6 +61,7 @@ class TransformerTaskTest(tf.test.TestCase):
FLAGS.train_steps = 2 FLAGS.train_steps = 2
FLAGS.validation_steps = 1 FLAGS.validation_steps = 1
FLAGS.batch_size = 8 FLAGS.batch_size = 8
FLAGS.max_length = 1
FLAGS.num_gpus = 1 FLAGS.num_gpus = 1
FLAGS.distribution_strategy = 'off' FLAGS.distribution_strategy = 'off'
FLAGS.dtype = 'fp32' FLAGS.dtype = 'fp32'
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment