Commit b4b675db authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 345588520
parent 55c284c8
@@ -16,12 +16,6 @@
1. Batching scheme
The examples encoded in the TFRecord files contain data in the format:
{'inputs': [variable length array of integers],
'targets': [variable length array of integers]}
Where integers in the arrays refer to tokens in the English and German vocab
file (named `vocab.ende.32768`).
Prior to batching, elements in the dataset are grouped by length (max between
'inputs' and 'targets' length). Each group is then batched such that:
group_batch_size * length <= batch_size.
@@ -37,32 +31,22 @@
This batching scheme decreases the fraction of padding tokens per training
batch, thus improving the training speed significantly.
"""
from typing import Dict, Optional
import dataclasses
import tensorflow as tf
import tensorflow_text as tftxt
from official.core import config_definitions as cfg
from official.core import input_reader
from official.nlp.data import data_loader
from official.nlp.data import data_loader_factory
# Buffer size for reading records from a TFRecord file. Each training file is
# 7.2 MB, so 8 MB allows an entire file to be kept in memory.
_READ_RECORD_BUFFER = 8 * 1000 * 1000
# Example grouping constants. Defines length boundaries for each group.
# These values are the defaults used in Tensor2Tensor.
_MIN_BOUNDARY = 8
_BOUNDARY_SCALE = 1.1
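As a hedged sketch of how such boundaries are typically derived in the Tensor2Tensor-style scheme described above (the helper below is illustrative, not the module's implementation):

def _illustrative_boundaries(max_length, min_boundary=8, boundary_scale=1.1):
  # Boundaries grow roughly geometrically, but always by at least one token,
  # so buckets for short sequences stay fine-grained.
  boundaries = []
  x = min_boundary
  while x < max_length:
    boundaries.append(x)
    x = max(x + 1, int(x * boundary_scale))
  return boundaries

# e.g. _illustrative_boundaries(64) starts [8, 9, 10, ...] and stops below 64.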
def _filter_max_length(example, max_length=256):
"""Indicates whether the example's length is lower than the maximum length."""
return tf.logical_and(
tf.size(example[0]) <= max_length,
tf.size(example[1]) <= max_length)
def _get_example_length(example):
"""Returns the maximum length between the example inputs and targets."""
length = tf.maximum(tf.shape(example[0])[0], tf.shape(example[1])[0])
@@ -181,7 +165,11 @@ class WMTDataConfig(cfg.DataConfig):
"""Data config for WMT translation."""
max_seq_length: int = 64
static_batch: bool = False
vocab_file: str = ''
sentencepiece_model_path: str = ''
src_lang: str = ''
tgt_lang: str = ''
transform_and_batch: bool = True
has_unique_id: bool = False
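A hedged usage sketch of this config; the file paths and language codes below are placeholders, and during training `global_batch_size` acts as a per-batch token budget rather than an example count.

config = WMTDataConfig(
    input_path='/tmp/wmt_train-*.tfrecord',     # hypothetical path
    sentencepiece_model_path='/tmp/wmt.model',  # hypothetical path
    src_lang='en',
    tgt_lang='de',
    max_seq_length=64,
    global_batch_size=4096,  # token budget when is_training=True
    is_training=True,
    static_batch=False)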
@data_loader_factory.register_data_loader_cls(WMTDataConfig)
@@ -193,24 +181,20 @@ class WMTDataLoader(data_loader.DataLoader):
self._max_seq_length = params.max_seq_length
self._static_batch = params.static_batch
self._global_batch_size = params.global_batch_size
if self._params.transform_and_batch:
self._tokenizer = tftxt.SentencepieceTokenizer(
model=tf.io.gfile.GFile(params.sentencepiece_model_path, 'rb').read(),
add_eos=True)
def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example."""
if self._params.is_training:
name_to_features = {
'inputs': tf.io.VarLenFeature(tf.int64),
'targets': tf.io.VarLenFeature(tf.int64)
}
example = tf.io.parse_single_example(record, name_to_features)
example['inputs'] = tf.sparse.to_dense(example['inputs'])
example['targets'] = tf.sparse.to_dense(example['targets'])
else:
name_to_features = {
'inputs': tf.io.VarLenFeature(tf.int64),
'unique_id': tf.io.FixedLenFeature([], tf.int64)
}
example = tf.io.parse_single_example(record, name_to_features)
example['inputs'] = tf.sparse.to_dense(example['inputs'])
name_to_features = {
self._params.src_lang: tf.io.FixedLenFeature([], tf.string),
self._params.tgt_lang: tf.io.FixedLenFeature([], tf.string),
}
if self._params.has_unique_id:
name_to_features['unique_id'] = tf.io.FixedLenFeature([], tf.int64)
example = tf.io.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
for name in example:
@@ -220,21 +204,64 @@ class WMTDataLoader(data_loader.DataLoader):
example[name] = t
return example
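The remainder of the loop above is collapsed in this hunk; a minimal sketch of the int64-to-int32 cast pattern it describes, assuming `example` maps feature names to tensors (not the exact elided lines):

for name in list(example.keys()):
  t = example[name]
  if t.dtype == tf.int64:
    t = tf.cast(t, tf.int32)
  example[name] = t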
def _tokenize(self, inputs) -> Dict[str, tf.Tensor]:
tokenized_inputs = {}
for k, v in inputs.items():
if k == self._params.src_lang:
tokenized_inputs['inputs'] = self._tokenizer.tokenize(v)
elif k == self._params.tgt_lang:
tokenized_inputs['targets'] = self._tokenizer.tokenize(v)
else:
tokenized_inputs[k] = v
return tokenized_inputs
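A hedged, standalone sketch of the tokenization step using the same tensorflow_text API as above; the model path is a placeholder.

import tensorflow as tf
import tensorflow_text as tftxt

sp_model = tf.io.gfile.GFile('/tmp/wmt.model', 'rb').read()  # hypothetical path
tokenizer = tftxt.SentencepieceTokenizer(model=sp_model, add_eos=True)
# Maps a raw string to a 1-D tensor of subword ids ending with the EOS id.
ids = tokenizer.tokenize(tf.constant('a small test sentence'))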
def _filter_max_length(self, inputs):
"""Keeps only examples whose inputs and targets fit within max_seq_length."""
return tf.logical_and(
tf.shape(inputs['inputs'])[0] <= self._max_seq_length,
tf.shape(inputs['targets'])[0] <= self._max_seq_length)
def _maybe_truncate(self, inputs):
truncated_inputs = {}
for k, v in inputs.items():
if k == 'inputs' or k == 'targets':
truncated_inputs[k] = tf.pad(
v[:self._max_seq_length - 1], [[0, 1]],
constant_values=1) if tf.shape(v)[0] > self._max_seq_length else v
else:
truncated_inputs[k] = v
return truncated_inputs
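To illustrate the truncation rule above, assuming max_seq_length=4 and an EOS id of 1 (the constant_values used in the pad): a length-6 sequence keeps its first three tokens plus an EOS.

v = tf.constant([11, 12, 13, 14, 15, 16])
max_seq_length = 4
truncated = tf.pad(v[:max_seq_length - 1], [[0, 1]], constant_values=1)
# truncated == [11, 12, 13, 1]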
def _tokenize_bucketize_and_batch(
self,
dataset,
input_context: Optional[tf.distribute.InputContext] = None):
# pylint: disable=g-long-lambda
dataset = dataset.filter(lambda x: _filter_max_length(
(x['inputs'], x['targets']), self._max_seq_length))
# pylint: enable=g-long-lambda
dataset = dataset.map(
self._tokenize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
if self._params.is_training:
dataset = dataset.filter(self._filter_max_length)
else:
dataset = dataset.map(
self._maybe_truncate,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
per_replica_batch_size = input_context.get_per_replica_batch_size(
self._global_batch_size) if input_context else self._global_batch_size
if self._static_batch:
padded_shapes = {}
for name, _ in dataset.element_spec.items():
if name == 'unique_id':
padded_shapes[name] = []
else:
padded_shapes[name] = [self._max_seq_length
] if self._static_batch else [None]
batch_size = per_replica_batch_size
if self._params.is_training:
batch_size = int(batch_size // self._max_seq_length)
dataset = dataset.padded_batch(
batch_size,
padded_shapes,
drop_remainder=True)
else:
@@ -245,27 +272,24 @@ class WMTDataLoader(data_loader.DataLoader):
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
return dataset
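The non-static branch is collapsed in this hunk; it groups examples of similar length before batching them under the token budget. A hedged sketch of that idea using the public tf.data helper (not necessarily the exact hidden implementation):

def _illustrative_bucket_batch(dataset, token_budget, boundaries=(8, 16, 32, 64)):
  """Groups examples of similar length, then batches each group to fit a token budget."""
  def _example_len(ex):
    return tf.maximum(tf.shape(ex['inputs'])[0], tf.shape(ex['targets'])[0])
  boundaries = list(boundaries)
  # One batch size per bucket; the last bucket is capped at twice the final boundary.
  bucket_caps = boundaries + [boundaries[-1] * 2]
  batch_sizes = [max(1, token_budget // cap) for cap in bucket_caps]
  return dataset.apply(
      tf.data.experimental.bucket_by_sequence_length(
          _example_len, boundaries, batch_sizes))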
def _inference_padded_batch(
self,
dataset,
input_context: Optional[tf.distribute.InputContext] = None):
padded_shapes = {}
for name, _ in dataset.element_spec.items():
if name == 'unique_id':
padded_shapes[name] = []
else:
padded_shapes[name] = [self._max_seq_length
] if self._static_batch else [None]
per_replica_batch_size = input_context.get_per_replica_batch_size(
self._global_batch_size) if input_context else self._global_batch_size
return dataset.padded_batch(
per_replica_batch_size, padded_shapes, drop_remainder=True)
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a tf.dataset.Dataset."""
decoder_fn = None
# Only decode for TFRecords.
if self._params.input_path:
decoder_fn = self._decode
def _identity(
dataset, input_context: Optional[tf.distribute.InputContext] = None):
del input_context
return dataset
transform_and_batch_fn = _identity
if self._params.transform_and_batch:
transform_and_batch_fn = self._tokenize_bucketize_and_batch
reader = input_reader.InputReader(
params=self._params,
decoder_fn=decoder_fn,
transform_and_batch_fn=transform_and_batch_fn)
return reader.read(input_context)
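A hedged end-to-end sketch of the loader, including how a distribution strategy supplies the InputContext that load() consumes (paths are placeholders; distribute_datasets_from_function assumes TF 2.4+):

config = WMTDataConfig(
    input_path='/tmp/wmt_train-*.tfrecord',     # hypothetical path
    sentencepiece_model_path='/tmp/wmt.model',  # hypothetical path
    src_lang='en', tgt_lang='de',
    max_seq_length=64, global_batch_size=4096, is_training=True)
loader = WMTDataLoader(config)

# The strategy passes a tf.distribute.InputContext into load(), which is how
# per_replica_batch_size is derived above.
strategy = tf.distribute.MirroredStrategy()
dist_dataset = strategy.distribute_datasets_from_function(loader.load)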
@@ -15,74 +15,113 @@
# ==============================================================================
"""Tests for official.nlp.data.wmt_dataloader."""
import os
import random
from absl import logging
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from sentencepiece import SentencePieceTrainer
from official.nlp.data import wmt_dataloader
def _create_fake_dataset(output_path):
"""Creates a fake dataset."""
writer = tf.io.TFRecordWriter(output_path)
def _generate_line_file(filepath, lines):
with tf.io.gfile.GFile(filepath, 'w') as f:
for l in lines:
f.write('{}\n'.format(l))
def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return f
for _ in range(20):
features = {}
seq_length = random.randint(20, 40)
input_ids = np.random.randint(100, size=(seq_length))
features['inputs'] = create_int_feature(input_ids)
seq_length = random.randint(10, 80)
targets = np.random.randint(100, size=(seq_length))
features['targets'] = create_int_feature(targets)
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writer.write(tf_example.SerializeToString())
def _generate_record_file(filepath, src_lines, tgt_lines, unique_id=False):
writer = tf.io.TFRecordWriter(filepath)
for i, (src, tgt) in enumerate(zip(src_lines, tgt_lines)):
features = {
'en': tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[src.encode()])),
'reverse_en': tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[tgt.encode()])),
}
if unique_id:
features['unique_id'] = tf.train.Feature(
int64_list=tf.train.Int64List(value=[i]))
example = tf.train.Example(
features=tf.train.Features(
feature=features))
writer.write(example.SerializeToString())
writer.close()
class WMTDataLoaderTest(tf.test.TestCase):
def _train_sentencepiece(input_path, vocab_size, model_path, eos_id=1):
argstr = ' '.join([
f'--input={input_path}', f'--vocab_size={vocab_size}',
'--character_coverage=0.995',
f'--model_prefix={model_path}', '--model_type=bpe',
'--bos_id=-1', '--pad_id=0', f'--eos_id={eos_id}', '--unk_id=2'
])
SentencePieceTrainer.Train(argstr)
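As a hedged check of the assumption the loader relies on, the ids baked into the trained model match the flags above (pad=0, eos=1, unk=2); the path here is a placeholder.

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load('/tmp/sp.model')  # hypothetical: model_prefix + '.model'
assert sp.pad_id() == 0 and sp.eos_id() == 1 and sp.unk_id() == 2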
def test_load_dataset(self):
batch_tokens_size = 100
train_data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
_create_fake_dataset(train_data_path)
data_config = wmt_dataloader.WMTDataConfig(
input_path=train_data_path,
max_seq_length=35,
global_batch_size=batch_tokens_size,
is_training=True,
static_batch=False)
dataset = wmt_dataloader.WMTDataLoader(data_config).load()
examples = next(iter(dataset))
inputs, targets = examples['inputs'], examples['targets']
logging.info('dynamic inputs=%s targets=%s', inputs, targets)
class WMTDataLoaderTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(WMTDataLoaderTest, self).setUp()
self._temp_dir = self.get_temp_dir()
src_lines = [
'abc ede fg',
'bbcd ef a g',
'de f a a g'
]
tgt_lines = [
'dd cc a ef g',
'bcd ef a g',
'gef cd ba'
]
self._record_train_input_path = os.path.join(self._temp_dir, 'train.record')
_generate_record_file(self._record_train_input_path, src_lines, tgt_lines)
self._record_test_input_path = os.path.join(self._temp_dir, 'test.record')
_generate_record_file(self._record_test_input_path, src_lines, tgt_lines,
unique_id=True)
self._sentencepeice_input_path = os.path.join(self._temp_dir, 'inputs.txt')
_generate_line_file(self._sentencepeice_input_path, src_lines + tgt_lines)
sentencepeice_model_prefix = os.path.join(self._temp_dir, 'sp')
_train_sentencepiece(self._sentencepeice_input_path, 20,
sentencepeice_model_prefix)
self._sentencepeice_model_path = '{}.model'.format(
sentencepeice_model_prefix)
@parameterized.named_parameters(
('train_static', True, True, 100, (2, 35)),
('train_non_static', True, False, 100, (12, 7)),
('non_train_static', False, True, 3, (3, 35)),
('non_train_non_static', False, False, 50, (2, 7)),)
def test_load_dataset(
self, is_training, static_batch, batch_size, expected_shape):
data_config = wmt_dataloader.WMTDataConfig(
input_path=self._record_train_input_path
if is_training else self._record_test_input_path,
max_seq_length=35,
global_batch_size=batch_size,
is_training=is_training,
static_batch=static_batch,
src_lang='en',
tgt_lang='reverse_en',
sentencepiece_model_path=self._sentencepeice_model_path)
dataset = wmt_dataloader.WMTDataLoader(data_config).load()
examples = next(iter(dataset))
inputs, targets = examples['inputs'], examples['targets']
logging.info('inputs=%s targets=%s', inputs, targets)
self.assertEqual(inputs.shape, expected_shape)
self.assertEqual(targets.shape, expected_shape)
def test_load_dataset_raise_invalid_window(self):
batch_tokens_size = 10 # this is too small to form buckets.
train_data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
_create_fake_dataset(train_data_path)
data_config = wmt_dataloader.WMTDataConfig(
input_path=self._record_train_input_path,
max_seq_length=100,
global_batch_size=batch_tokens_size,
is_training=True,
static_batch=False,
src_lang='en',
tgt_lang='reverse_en',
sentencepiece_model_path=self._sentencepeice_model_path)
with self.assertRaisesRegex(
ValueError, 'The token budget, global batch size, is too small.*'):
_ = wmt_dataloader.WMTDataLoader(data_config).load()
@@ -53,6 +53,7 @@ class Seq2SeqTransformer(tf.keras.Model):
encoder_layer=None,
decoder_layer=None,
dtype=tf.float32,
eos_id=EOS_ID,
**kwargs):
"""Initialize layers to build Transformer model.
@@ -69,6 +70,7 @@ class Seq2SeqTransformer(tf.keras.Model):
encoder_layer: An initialized encoder layer.
decoder_layer: An initialized decoder layer.
dtype: float dtype.
eos_id: Id of the end-of-sentence token.
**kwargs: other keyword arguments.
"""
super(Seq2SeqTransformer, self).__init__(**kwargs)
@@ -81,6 +83,7 @@ class Seq2SeqTransformer(tf.keras.Model):
self._beam_size = beam_size
self._alpha = alpha
self._dtype = dtype
self._eos_id = eos_id
self.embedding_lookup = keras_nlp.layers.OnDeviceEmbedding(
vocab_size=self._vocab_size,
embedding_width=self._embedding_width,
@@ -102,6 +105,7 @@ class Seq2SeqTransformer(tf.keras.Model):
"padded_decode": self._padded_decode,
"decode_max_length": self._decode_max_length,
"dtype": self._dtype,
"eos_id": self._eos_id,
"extra_decode_length": self._extra_decode_length,
"beam_size": self._beam_size,
"alpha": self._alpha,
@@ -226,7 +230,7 @@ class Seq2SeqTransformer(tf.keras.Model):
beam_size=self._beam_size,
alpha=self._alpha,
max_decode_length=max_decode_length,
eos_id=self._eos_id,
padded_decode=self._padded_decode,
dtype=self._dtype)
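A hedged sketch of how the new eos_id parameter is meant to be wired up when the vocabulary comes from a SentencePiece model whose EOS id differs from the default constant; the path is a placeholder and the remaining constructor arguments are omitted because they are not shown in this diff.

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load('/tmp/wmt.model')  # hypothetical path
eos_id = sp.eos_id()  # e.g. 1 for the models trained in the test above
# model = Seq2SeqTransformer(..., eos_id=eos_id)  # other args unchanged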