"git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "7fd5fce4b22a9e36c948b733bbcf1fba1321d1e8"
Commit 0ffd6ab0 authored by Jay Shi, committed by A. Unique TensorFlower
Browse files

Parallelize the batching of the Shakespeare benchmark.

PiperOrigin-RevId: 364541922
parent d5afcc72
......@@ -108,7 +108,8 @@ def get_dataset(path_to_file, batch_size=None, seq_length=SEQ_LENGTH):
# Split text into sequence length + 1 chunks to create examples
text_as_int = np.array([char2idx[c] for c in text])
char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)
sequences = char_dataset.batch(seq_length+1, drop_remainder=True)
sequences = char_dataset.batch(
seq_length + 1, drop_remainder=True, num_parallel_calls=tf.data.AUTOTUNE)
def split_input_target(chunk):
input_text = chunk[:-1]
......@@ -116,7 +117,8 @@ def get_dataset(path_to_file, batch_size=None, seq_length=SEQ_LENGTH):
return input_text, tf.one_hot(target_text, len(vocab))
dataset = sequences.map(split_input_target)
dataset = dataset.shuffle(10000).repeat()
dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.batch(
batch_size, drop_remainder=True, num_parallel_calls=tf.data.AUTOTUNE)
return dataset, idx2char, char2idx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment