Commit 5b6be76b authored by Hongkun Yu, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 335106919
parent e1c78a72
......@@ -182,6 +182,7 @@ class WMTDataConfig(cfg.DataConfig):
"""Data config for WMT translation."""
max_seq_length: int = 64
static_batch: bool = False
vocab_file: str = ''
@data_loader_factory.register_data_loader_cls(WMTDataConfig)
......@@ -196,13 +197,21 @@ class WMTDataLoader(data_loader.DataLoader):
def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example."""
name_to_features = {
'inputs': tf.io.VarLenFeature(tf.int64),
'targets': tf.io.VarLenFeature(tf.int64)
}
example = tf.io.parse_single_example(record, name_to_features)
example['inputs'] = tf.sparse.to_dense(example['inputs'])
example['targets'] = tf.sparse.to_dense(example['targets'])
if self._params.is_training:
name_to_features = {
'inputs': tf.io.VarLenFeature(tf.int64),
'targets': tf.io.VarLenFeature(tf.int64)
}
example = tf.io.parse_single_example(record, name_to_features)
example['inputs'] = tf.sparse.to_dense(example['inputs'])
example['targets'] = tf.sparse.to_dense(example['targets'])
else:
name_to_features = {
'inputs': tf.io.VarLenFeature(tf.int64),
'unique_id': tf.io.FixedLenFeature([], tf.int64)
}
example = tf.io.parse_single_example(record, name_to_features)
example['inputs'] = tf.sparse.to_dense(example['inputs'])
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
for name in example:
......@@ -224,8 +233,7 @@ class WMTDataLoader(data_loader.DataLoader):
self._global_batch_size) if input_context else self._global_batch_size
if self._static_batch:
padded_shapes = dict([(name, [self._max_seq_length])
for name, _ in dataset.element_spec.items()
])
for name, _ in dataset.element_spec.items()])
dataset = dataset.padded_batch(
int(per_replica_batch_size // self._max_seq_length),
padded_shapes,
......@@ -238,10 +246,27 @@ class WMTDataLoader(data_loader.DataLoader):
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
return dataset
def _inference_padded_batch(
    self,
    dataset,
    input_context: Optional[tf.distribute.InputContext] = None):
  """Pads and batches a non-training dataset without token bucketing.

  The scalar 'unique_id' feature is given an empty padded shape; every other
  feature is padded along one dimension, either to `self._max_seq_length`
  (when `self._static_batch` is set) or to `None`, which `padded_batch`
  treats as "pad to the longest element in the batch".

  Args:
    dataset: dataset exposing `element_spec` (a dict keyed by feature name)
      and `padded_batch`.
    input_context: optional distribution context; when given, the global batch
      size is converted to a per-replica batch size.

  Returns:
    The dataset batched with `drop_remainder=True`.
  """
  # One sequence dimension shared by every non-id feature.
  sequence_dim = self._max_seq_length if self._static_batch else None
  padded_shapes = {
      feature: [] if feature == 'unique_id' else [sequence_dim]
      for feature in dataset.element_spec
  }

  if input_context:
    batch_size = input_context.get_per_replica_batch_size(
        self._global_batch_size)
  else:
    batch_size = self._global_batch_size

  return dataset.padded_batch(batch_size, padded_shapes, drop_remainder=True)
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a tf.data.Dataset produced by an InputReader.

The reader is built from this loader's params, its `_decode` function and a
transform-and-batch function.

Args:
input_context: optional distribution context, forwarded to `reader.read`.

Returns:
The result of `reader.read(input_context)`.
"""
reader = input_reader.InputReader(
params=self._params,
decoder_fn=self._decode,
# Pre-change line of this diff hunk: bucketized batching in every mode.
transform_and_batch_fn=self._bucketize_and_batch)
# Post-change lines of this diff hunk: bucketize only when training,
# otherwise use the plain padded batching of `_inference_padded_batch`.
transform_and_batch_fn=self._bucketize_and_batch
if self._params.is_training else self._inference_padded_batch)
return reader.read(input_context)
......@@ -55,6 +55,7 @@ class WMTDataLoaderTest(tf.test.TestCase):
input_path=train_data_path,
max_seq_length=35,
global_batch_size=batch_tokens_size,
is_training=True,
static_batch=False)
dataset = wmt_dataloader.WMTDataLoader(data_config).load()
examples = next(iter(dataset))
......@@ -64,6 +65,7 @@ class WMTDataLoaderTest(tf.test.TestCase):
input_path=train_data_path,
max_seq_length=35,
global_batch_size=batch_tokens_size,
is_training=True,
static_batch=True)
dataset = wmt_dataloader.WMTDataLoader(data_config).load()
examples = next(iter(dataset))
......@@ -79,7 +81,8 @@ class WMTDataLoaderTest(tf.test.TestCase):
data_config = wmt_dataloader.WMTDataConfig(
input_path=train_data_path,
max_seq_length=100,
global_batch_size=batch_tokens_size)
global_batch_size=batch_tokens_size,
is_training=True)
with self.assertRaisesRegex(
ValueError, 'The token budget, global batch size, is too small.*'):
_ = wmt_dataloader.WMTDataLoader(data_config).load()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment