Commit 0ffd6ab0 authored by Jay Shi's avatar Jay Shi Committed by A. Unique TensorFlower
Browse files

Parallelize the batching of the Shakespeare benchmark.

PiperOrigin-RevId: 364541922
parent d5afcc72
...@@ -108,7 +108,8 @@ def get_dataset(path_to_file, batch_size=None, seq_length=SEQ_LENGTH): ...@@ -108,7 +108,8 @@ def get_dataset(path_to_file, batch_size=None, seq_length=SEQ_LENGTH):
# Split text into sequence length + 1 chunks to create examples # Split text into sequence length + 1 chunks to create examples
text_as_int = np.array([char2idx[c] for c in text]) text_as_int = np.array([char2idx[c] for c in text])
char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int) char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)
sequences = char_dataset.batch(seq_length+1, drop_remainder=True) sequences = char_dataset.batch(
seq_length + 1, drop_remainder=True, num_parallel_calls=tf.data.AUTOTUNE)
def split_input_target(chunk): def split_input_target(chunk):
input_text = chunk[:-1] input_text = chunk[:-1]
...@@ -116,7 +117,8 @@ def get_dataset(path_to_file, batch_size=None, seq_length=SEQ_LENGTH): ...@@ -116,7 +117,8 @@ def get_dataset(path_to_file, batch_size=None, seq_length=SEQ_LENGTH):
return input_text, tf.one_hot(target_text, len(vocab)) return input_text, tf.one_hot(target_text, len(vocab))
dataset = sequences.map(split_input_target) dataset = sequences.map(split_input_target)
dataset = dataset.shuffle(10000).repeat() dataset = dataset.shuffle(10000).repeat()
dataset = dataset.batch(batch_size, drop_remainder=True) dataset = dataset.batch(
batch_size, drop_remainder=True, num_parallel_calls=tf.data.AUTOTUNE)
return dataset, idx2char, char2idx return dataset, idx2char, char2idx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment