"docs/en/vscode:/vscode.git/clone" did not exist on "daf74733ed06e4da163525e0e0523921ff33e46e"
Commit f5014889 authored by Ruoxin Sang, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 312194218
parent 25160730
......@@ -51,7 +51,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import os
from absl import logging
......@@ -157,7 +156,7 @@ def _batch_examples(dataset, batch_size, max_length):
# Create list of batch sizes for each bucket_id, so that
# bucket_batch_size[bucket_id] * buckets_max[bucket_id] <= batch_size
bucket_batch_sizes = [batch_size // x for x in buckets_max]
bucket_batch_sizes = [int(batch_size) // x for x in buckets_max]
# bucket_id will be a tensor, so convert this list to a tensor as well.
bucket_batch_sizes = tf.constant(bucket_batch_sizes, dtype=tf.int64)
......@@ -270,7 +269,8 @@ def _read_and_batch_from_files(
def _generate_synthetic_data(params):
"""Create synthetic data based on the parameter batch size."""
batch = length = int(math.sqrt(params["batch_size"]))
batch_size = int(params["batch_size"] // params["max_length"])
length = params["max_length"]
dataset = model_helpers.generate_synthetic_data(
input_shape=tf.TensorShape([length]),
input_value=1,
......@@ -279,7 +279,11 @@ def _generate_synthetic_data(params):
label_value=1,
label_dtype=tf.int64,
)
return dataset.batch(batch, drop_remainder=True)
if params["static_batch"]:
dataset = dataset.batch(batch_size, drop_remainder=True)
else:
dataset = dataset.padded_batch(batch_size, ([None], [None]))
return dataset
def train_input_fn(params, ctx=None):
......
......@@ -168,8 +168,6 @@ class TransformerTask(object):
tpu_address=flags_obj.tpu or "")
if self.use_tpu:
params["num_replicas"] = self.distribution_strategy.num_replicas_in_sync
if not params["static_batch"]:
raise ValueError("TPU requires static batch for input data.")
else:
logging.info("Running transformer with num_gpus = %d", num_gpus)
......
......@@ -61,6 +61,7 @@ class TransformerTaskTest(tf.test.TestCase):
FLAGS.train_steps = 2
FLAGS.validation_steps = 1
FLAGS.batch_size = 8
FLAGS.max_length = 1
FLAGS.num_gpus = 1
FLAGS.distribution_strategy = 'off'
FLAGS.dtype = 'fp32'
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment