Commit b4b675db authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 345588520
parent 55c284c8
@@ -16,12 +16,6 @@
 1. Batching scheme
-   The examples encoded in the TFRecord files contain data in the format:
-     {'inputs': [variable length array of integers],
-      'targets': [variable length array of integers]}
-   Where integers in the arrays refer to tokens in the English and German vocab
-   file (named `vocab.ende.32768`).
 Prior to batching, elements in the dataset are grouped by length (max between
 'inputs' and 'targets' length). Each group is then batched such that:
   group_batch_size * length <= batch_size.
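The token-budget batching described in this docstring can be illustrated with a standard tf.data bucketing transform. The sketch below is not part of this commit; it only mirrors the idea under simplified assumptions (toy data, hand-picked boundaries): examples are grouped by length, and each bucket's batch size is chosen so that batch_size * bucket_length stays within the token budget.

# Illustrative sketch only; the loader's real bucketing lives in this file.
import tensorflow as tf

# Toy stand-in for decoded, tokenized translation examples.
dataset = tf.data.Dataset.from_tensor_slices({
    'inputs': tf.ragged.constant([[1, 2, 3], [4, 5, 6, 7, 8], list(range(20))]),
    'targets': tf.ragged.constant([[1, 2], [3, 4, 5, 6], list(range(18))]),
})

token_budget = 100                    # plays the role of global_batch_size
boundaries = [4, 8, 16, 32]           # example length boundaries
# One batch size per bucket (len(boundaries) + 1), chosen so that
# batch_size * bucket_length stays within the token budget.
bucket_batch_sizes = [max(1, token_budget // b) for b in boundaries + [64]]

def example_length(example):
  return tf.maximum(tf.shape(example['inputs'])[0],
                    tf.shape(example['targets'])[0])

batched = dataset.apply(
    tf.data.experimental.bucket_by_sequence_length(
        example_length, boundaries, bucket_batch_sizes))

for batch in batched:
  print(batch['inputs'].shape, batch['targets'].shape)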
@@ -37,32 +31,22 @@
 This batching scheme decreases the fraction of padding tokens per training
 batch, thus improving the training speed significantly.
 """
-from typing import Optional
+from typing import Dict, Optional

 import dataclasses
 import tensorflow as tf
+import tensorflow_text as tftxt

 from official.core import config_definitions as cfg
 from official.core import input_reader
 from official.nlp.data import data_loader
 from official.nlp.data import data_loader_factory

-# Buffer size for reading records from a TFRecord file. Each training file is
-# 7.2 MB, so 8 MB allows an entire file to be kept in memory.
-_READ_RECORD_BUFFER = 8 * 1000 * 1000
-
 # Example grouping constants. Defines length boundaries for each group.
 # These values are the defaults used in Tensor2Tensor.
 _MIN_BOUNDARY = 8
 _BOUNDARY_SCALE = 1.1

-def _filter_max_length(example, max_length=256):
-  """Indicates whether the example's length is lower than the maximum length."""
-  return tf.logical_and(
-      tf.size(example[0]) <= max_length,
-      tf.size(example[1]) <= max_length)
-

 def _get_example_length(example):
   """Returns the maximum length between the example inputs and targets."""
   length = tf.maximum(tf.shape(example[0])[0], tf.shape(example[1])[0])
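For context on the grouping constants kept above: Tensor2Tensor-style boundary generation grows bucket boundaries roughly geometrically from _MIN_BOUNDARY by _BOUNDARY_SCALE until the maximum length is reached. The helper below is a rough illustration, not code from this file or this commit.

# Illustrative helper; the real boundary computation is elsewhere in this file
# and is not part of this diff.
_MIN_BOUNDARY = 8
_BOUNDARY_SCALE = 1.1

def make_boundaries(max_length, min_boundary=_MIN_BOUNDARY,
                    boundary_scale=_BOUNDARY_SCALE):
  """Grows bucket boundaries geometrically until max_length is reached."""
  boundaries = []
  x = min_boundary
  while x < max_length:
    boundaries.append(x)
    # Grow by at least 1 so the loop always terminates.
    x = max(x + 1, int(x * boundary_scale))
  return boundaries

print(make_boundaries(64))  # [8, 9, 10, ..., 55, 60]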
@@ -181,7 +165,11 @@ class WMTDataConfig(cfg.DataConfig):
   """Data config for WMT translation."""
   max_seq_length: int = 64
   static_batch: bool = False
-  vocab_file: str = ''
+  sentencepiece_model_path: str = ''
+  src_lang: str = ''
+  tgt_lang: str = ''
+  transform_and_batch: bool = True
+  has_unique_id: bool = False


 @data_loader_factory.register_data_loader_cls(WMTDataConfig)
@@ -193,24 +181,20 @@ class WMTDataLoader(data_loader.DataLoader):
     self._max_seq_length = params.max_seq_length
     self._static_batch = params.static_batch
     self._global_batch_size = params.global_batch_size
+    if self._params.transform_and_batch:
+      self._tokenizer = tftxt.SentencepieceTokenizer(
+          model=tf.io.gfile.GFile(params.sentencepiece_model_path, 'rb').read(),
+          add_eos=True)

   def _decode(self, record: tf.Tensor):
     """Decodes a serialized tf.Example."""
-    if self._params.is_training:
-      name_to_features = {
-          'inputs': tf.io.VarLenFeature(tf.int64),
-          'targets': tf.io.VarLenFeature(tf.int64)
-      }
-      example = tf.io.parse_single_example(record, name_to_features)
-      example['inputs'] = tf.sparse.to_dense(example['inputs'])
-      example['targets'] = tf.sparse.to_dense(example['targets'])
-    else:
-      name_to_features = {
-          'inputs': tf.io.VarLenFeature(tf.int64),
-          'unique_id': tf.io.FixedLenFeature([], tf.int64)
-      }
-      example = tf.io.parse_single_example(record, name_to_features)
-      example['inputs'] = tf.sparse.to_dense(example['inputs'])
+    name_to_features = {
+        self._params.src_lang: tf.io.FixedLenFeature([], tf.string),
+        self._params.tgt_lang: tf.io.FixedLenFeature([], tf.string),
+    }
+    if self._params.has_unique_id:
+      name_to_features['unique_id'] = tf.io.FixedLenFeature([], tf.int64)
+    example = tf.io.parse_single_example(record, name_to_features)

     # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
     # So cast all int64 to int32.
     for name in example:
@@ -220,21 +204,64 @@ class WMTDataLoader(data_loader.DataLoader):
         example[name] = t
     return example

-  def _bucketize_and_batch(
+  def _tokenize(self, inputs) -> Dict[str, tf.Tensor]:
+    tokenized_inputs = {}
+    for k, v in inputs.items():
+      if k == self._params.src_lang:
+        tokenized_inputs['inputs'] = self._tokenizer.tokenize(v)
+      elif k == self._params.tgt_lang:
+        tokenized_inputs['targets'] = self._tokenizer.tokenize(v)
+      else:
+        tokenized_inputs[k] = v
+    return tokenized_inputs
+
+  def _filter_max_length(self, inputs):
+    return tf.logical_and(
+        tf.shape(inputs['inputs'])[0] <= self._max_seq_length,
+        tf.shape(inputs['targets'])[0] <= self._max_seq_length)
+
+  def _maybe_truncate(self, inputs):
+    truncated_inputs = {}
+    for k, v in inputs.items():
+      if k == 'inputs' or k == 'targets':
+        truncated_inputs[k] = tf.pad(
+            v[:self._max_seq_length - 1], [[0, 1]],
+            constant_values=1) if tf.shape(v)[0] > self._max_seq_length else v
+      else:
+        truncated_inputs[k] = v
+    return truncated_inputs
+
+  def _tokenize_bucketize_and_batch(
       self,
       dataset,
       input_context: Optional[tf.distribute.InputContext] = None):
-    # pylint: disable=g-long-lambda
-    dataset = dataset.filter(lambda x: _filter_max_length(
-        (x['inputs'], x['targets']), self._max_seq_length))
-    # pylint: enable=g-long-lambda
+    dataset = dataset.map(
+        self._tokenize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+    if self._params.is_training:
+      dataset = dataset.filter(self._filter_max_length)
+    else:
+      dataset = dataset.map(
+          self._maybe_truncate,
+          num_parallel_calls=tf.data.experimental.AUTOTUNE)

     per_replica_batch_size = input_context.get_per_replica_batch_size(
         self._global_batch_size) if input_context else self._global_batch_size
     if self._static_batch:
-      padded_shapes = dict([(name, [self._max_seq_length])
-                            for name, _ in dataset.element_spec.items()])
+      padded_shapes = {}
+      for name, _ in dataset.element_spec.items():
+        if name == 'unique_id':
+          padded_shapes[name] = []
+        else:
+          padded_shapes[name] = [self._max_seq_length
+                                ] if self._static_batch else [None]
+      batch_size = per_replica_batch_size
+      if self._params.is_training:
+        batch_size = int(batch_size // self._max_seq_length)
       dataset = dataset.padded_batch(
-          int(per_replica_batch_size // self._max_seq_length),
+          batch_size,
          padded_shapes,
          drop_remainder=True)
     else:
@@ -245,27 +272,24 @@ class WMTDataLoader(data_loader.DataLoader):
     dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
     return dataset

-  def _inference_padded_batch(
-      self,
-      dataset,
-      input_context: Optional[tf.distribute.InputContext] = None):
-    padded_shapes = {}
-    for name, _ in dataset.element_spec.items():
-      if name == 'unique_id':
-        padded_shapes[name] = []
-      else:
-        padded_shapes[name] = [self._max_seq_length
-                              ] if self._static_batch else [None]
-    per_replica_batch_size = input_context.get_per_replica_batch_size(
-        self._global_batch_size) if input_context else self._global_batch_size
-    return dataset.padded_batch(
-        per_replica_batch_size, padded_shapes, drop_remainder=True)
-
   def load(self, input_context: Optional[tf.distribute.InputContext] = None):
     """Returns a tf.dataset.Dataset."""
+    decoder_fn = None
+    # Only decode for TFRecords.
+    if self._params.input_path:
+      decoder_fn = self._decode
+
+    def _identity(
+        dataset, input_context: Optional[tf.distribute.InputContext] = None):
+      del input_context
+      return dataset
+
+    transform_and_batch_fn = _identity
+    if self._params.transform_and_batch:
+      transform_and_batch_fn = self._tokenize_bucketize_and_batch
+
     reader = input_reader.InputReader(
         params=self._params,
-        decoder_fn=self._decode,
-        transform_and_batch_fn=self._bucketize_and_batch
-        if self._params.is_training else self._inference_padded_batch)
+        decoder_fn=decoder_fn,
+        transform_and_batch_fn=transform_and_batch_fn)
     return reader.read(input_context)
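For orientation, here is a minimal usage sketch of the reworked loader. It is not part of the commit: the paths are placeholders, and the 'en'/'reverse_en' keys simply mirror the feature names used in the updated test below.

# Illustrative only: 'train.record' and 'sp.model' are hypothetical paths.
from official.nlp.data import wmt_dataloader

config = wmt_dataloader.WMTDataConfig(
    input_path='/tmp/train.record',
    max_seq_length=64,
    global_batch_size=4096,       # acts as a per-batch token budget in training
    is_training=True,
    static_batch=False,
    src_lang='en',
    tgt_lang='reverse_en',
    sentencepiece_model_path='/tmp/sp.model')
dataset = wmt_dataloader.WMTDataLoader(config).load()
# Each element is a dict with 'inputs' and 'targets' of shape [batch, length],
# already tokenized with the SentencePiece model.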
@@ -15,74 +15,113 @@
 # ==============================================================================
 """Tests for official.nlp.data.wmt_dataloader."""
 import os
-import random

-from absl import logging
-import numpy as np
+from absl.testing import parameterized
 import tensorflow as tf
+from sentencepiece import SentencePieceTrainer

 from official.nlp.data import wmt_dataloader


-def _create_fake_dataset(output_path):
-  """Creates a fake dataset."""
-  writer = tf.io.TFRecordWriter(output_path)
-
-  def create_int_feature(values):
-    f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
-    return f
-
-  for _ in range(20):
-    features = {}
-    seq_length = random.randint(20, 40)
-    input_ids = np.random.randint(100, size=(seq_length))
-    features['inputs'] = create_int_feature(input_ids)
-    seq_length = random.randint(10, 80)
-    targets = np.random.randint(100, size=(seq_length))
-    features['targets'] = create_int_feature(targets)
-    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
-    writer.write(tf_example.SerializeToString())
+def _generate_line_file(filepath, lines):
+  with tf.io.gfile.GFile(filepath, 'w') as f:
+    for l in lines:
+      f.write('{}\n'.format(l))
+
+
+def _generate_record_file(filepath, src_lines, tgt_lines, unique_id=False):
+  writer = tf.io.TFRecordWriter(filepath)
+  for i, (src, tgt) in enumerate(zip(src_lines, tgt_lines)):
+    features = {
+        'en': tf.train.Feature(
+            bytes_list=tf.train.BytesList(value=[src.encode()])),
+        'reverse_en': tf.train.Feature(
+            bytes_list=tf.train.BytesList(value=[tgt.encode()])),
+    }
+    if unique_id:
+      features['unique_id'] = tf.train.Feature(
+          int64_list=tf.train.Int64List(value=[i]))
+    example = tf.train.Example(
+        features=tf.train.Features(feature=features))
+    writer.write(example.SerializeToString())
   writer.close()
-class WMTDataLoaderTest(tf.test.TestCase):
-
-  def test_load_dataset(self):
-    batch_tokens_size = 100
-    train_data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
-    _create_fake_dataset(train_data_path)
-    data_config = wmt_dataloader.WMTDataConfig(
-        input_path=train_data_path,
-        max_seq_length=35,
-        global_batch_size=batch_tokens_size,
-        is_training=True,
-        static_batch=False)
-    dataset = wmt_dataloader.WMTDataLoader(data_config).load()
-    examples = next(iter(dataset))
-    inputs, targets = examples['inputs'], examples['targets']
-    logging.info('dynamic inputs=%s targets=%s', inputs, targets)
+def _train_sentencepiece(input_path, vocab_size, model_path, eos_id=1):
+  argstr = ' '.join([
+      f'--input={input_path}', f'--vocab_size={vocab_size}',
+      '--character_coverage=0.995',
+      f'--model_prefix={model_path}', '--model_type=bpe',
+      '--bos_id=-1', '--pad_id=0', f'--eos_id={eos_id}', '--unk_id=2'
+  ])
+  SentencePieceTrainer.Train(argstr)
+
+
+class WMTDataLoaderTest(tf.test.TestCase, parameterized.TestCase):
+
+  def setUp(self):
+    super(WMTDataLoaderTest, self).setUp()
+    self._temp_dir = self.get_temp_dir()
+    src_lines = [
+        'abc ede fg',
+        'bbcd ef a g',
+        'de f a a g'
+    ]
+    tgt_lines = [
+        'dd cc a ef g',
+        'bcd ef a g',
+        'gef cd ba'
+    ]
+    self._record_train_input_path = os.path.join(self._temp_dir, 'train.record')
+    _generate_record_file(self._record_train_input_path, src_lines, tgt_lines)
+    self._record_test_input_path = os.path.join(self._temp_dir, 'test.record')
+    _generate_record_file(self._record_test_input_path, src_lines, tgt_lines,
+                          unique_id=True)
+    self._sentencepiece_input_path = os.path.join(self._temp_dir, 'inputs.txt')
+    _generate_line_file(self._sentencepiece_input_path, src_lines + tgt_lines)
+    sentencepiece_model_prefix = os.path.join(self._temp_dir, 'sp')
+    _train_sentencepiece(self._sentencepiece_input_path, 20,
+                         sentencepiece_model_prefix)
+    self._sentencepiece_model_path = '{}.model'.format(
+        sentencepiece_model_prefix)
+
+  @parameterized.named_parameters(
+      ('train_static', True, True, 100, (2, 35)),
+      ('train_non_static', True, False, 100, (12, 7)),
+      ('non_train_static', False, True, 3, (3, 35)),
+      ('non_train_non_static', False, False, 50, (2, 7)))
+  def test_load_dataset(
+      self, is_training, static_batch, batch_size, expected_shape):
     data_config = wmt_dataloader.WMTDataConfig(
-        input_path=train_data_path,
+        input_path=self._record_train_input_path
+        if is_training else self._record_test_input_path,
         max_seq_length=35,
-        global_batch_size=batch_tokens_size,
-        is_training=True,
-        static_batch=True)
+        global_batch_size=batch_size,
+        is_training=is_training,
+        static_batch=static_batch,
+        src_lang='en',
+        tgt_lang='reverse_en',
+        sentencepiece_model_path=self._sentencepiece_model_path)
     dataset = wmt_dataloader.WMTDataLoader(data_config).load()
     examples = next(iter(dataset))
     inputs, targets = examples['inputs'], examples['targets']
-    logging.info('static inputs=%s targets=%s', inputs, targets)
-    self.assertEqual(inputs.shape, (2, 35))
-    self.assertEqual(targets.shape, (2, 35))
+    self.assertEqual(inputs.shape, expected_shape)
+    self.assertEqual(targets.shape, expected_shape)

   def test_load_dataset_raise_invalid_window(self):
     batch_tokens_size = 10  # This is too small to form buckets.
-    train_data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
-    _create_fake_dataset(train_data_path)
     data_config = wmt_dataloader.WMTDataConfig(
-        input_path=train_data_path,
+        input_path=self._record_train_input_path,
         max_seq_length=100,
         global_batch_size=batch_tokens_size,
-        is_training=True)
+        is_training=True,
+        static_batch=False,
+        src_lang='en',
+        tgt_lang='reverse_en',
+        sentencepiece_model_path=self._sentencepiece_model_path)
     with self.assertRaisesRegex(
         ValueError, 'The token budget, global batch size, is too small.*'):
       _ = wmt_dataloader.WMTDataLoader(data_config).load()
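The expected shapes in the parameterized cases follow from the batching rules in the loader: with static batching during training, the token budget is divided by max_seq_length and every sequence is padded to max_seq_length, while at inference the batch size is used as-is. A small worked check for the static cases (the non-static shapes additionally depend on the toy data and bucketing):

# 'train_static': global_batch_size=100 is a token budget, max_seq_length=35.
global_batch_size = 100
max_seq_length = 35
batch = global_batch_size // max_seq_length  # 2 sequences per batch
print((batch, max_seq_length))               # (2, 35), the expected shape
# 'non_train_static': batch_size=3 is used directly, giving (3, 35).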
...
@@ -53,6 +53,7 @@ class Seq2SeqTransformer(tf.keras.Model):
                encoder_layer=None,
                decoder_layer=None,
                dtype=tf.float32,
+               eos_id=EOS_ID,
                **kwargs):
     """Initialize layers to build Transformer model.
@@ -69,6 +70,7 @@ class Seq2SeqTransformer(tf.keras.Model):
       encoder_layer: An initialized encoder layer.
       decoder_layer: An initialized decoder layer.
       dtype: float dtype.
+      eos_id: ID of the end-of-sentence token.
       **kwargs: other keyword arguments.
     """
     super(Seq2SeqTransformer, self).__init__(**kwargs)
@@ -81,6 +83,7 @@ class Seq2SeqTransformer(tf.keras.Model):
     self._beam_size = beam_size
     self._alpha = alpha
     self._dtype = dtype
+    self._eos_id = eos_id
     self.embedding_lookup = keras_nlp.layers.OnDeviceEmbedding(
         vocab_size=self._vocab_size,
         embedding_width=self._embedding_width,
@@ -102,6 +105,7 @@ class Seq2SeqTransformer(tf.keras.Model):
         "padded_decode": self._padded_decode,
         "decode_max_length": self._decode_max_length,
         "dtype": self._dtype,
+        "eos_id": self._eos_id,
         "extra_decode_length": self._extra_decode_length,
         "beam_size": self._beam_size,
         "alpha": self._alpha,
@@ -226,7 +230,7 @@ class Seq2SeqTransformer(tf.keras.Model):
         beam_size=self._beam_size,
         alpha=self._alpha,
         max_decode_length=max_decode_length,
-        eos_id=EOS_ID,
+        eos_id=self._eos_id,
         padded_decode=self._padded_decode,
         dtype=self._dtype)
...
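The new eos_id argument lets the decoding stop token follow the tokenizer rather than the hard-coded EOS_ID. A small sketch of where such a value might come from, assuming a trained SentencePiece model file like the one produced by the test above (the path is a placeholder, not from this commit):

# Sketch only: '/tmp/sp.model' stands in for a trained SentencePiece model.
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load('/tmp/sp.model')
eos_id = sp.eos_id()  # 1 for a model trained with --eos_id=1, as in the test
# This value could then be passed as the eos_id argument of Seq2SeqTransformer.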