"""imdb.py: training and evaluation input functions for the Keras IMDB data."""
from data.util import pad_sentence, to_dataset, START_CHAR, OOV_CHAR
import tensorflow as tf, numpy as np

# Number of label classes. Used as the size of the identity matrix whose rows
# serve as the one-hot label vectors in the input functions below.
NUM_CLASS = 2


def construct_input_fns(vocabulary_size, sentence_length,
                        batch_size, repeat=1):
  """Builds the training and evaluation input functions for IMDB.

  The corpus is loaded once via `tf.keras.datasets.imdb.load_data`; the two
  returned closures turn the train / test split into a dataset each time
  they are called.

  Args:
    vocabulary_size: Forwarded to `load_data` as `num_words`, i.e. how many
      of the most frequent tokens of the corpus are kept.
    sentence_length: Target length handed to `pad_sentence` for every
      sentence (the original documentation states that longer sentences get
      cut and shorter ones padded).
    batch_size: Forwarded unchanged to `to_dataset`.
    repeat: Forwarded unchanged to `to_dataset`; documented originally as
      the number of epochs. Defaults to 1.

  Returns:
    A `(train_input_fn, eval_input_fn)` tuple of zero-argument callables.
    Both return the result of `to_dataset`; the training one additionally
    shuffles it.
  """
  train_split, test_split = tf.keras.datasets.imdb.load_data(
      path="imdb.npz",
      num_words=vocabulary_size,
      skip_top=0,
      maxlen=None,
      seed=113,
      start_char=START_CHAR,
      oov_char=OOV_CHAR,
      # Real word indices begin right above the out-of-vocabulary code.
      index_from=OOV_CHAR + 1)
  train_sentences, train_labels = train_split
  test_sentences, test_labels = test_split

  def _as_dataset(sentences, labels):
    # Shared by both splits: pad every sentence, one-hot encode the labels by
    # picking rows of an identity matrix, then delegate to `to_dataset`.
    padded = np.array(
        [pad_sentence(sentence, sentence_length) for sentence in sentences])
    one_hot = np.eye(NUM_CLASS)[labels]
    return to_dataset(padded, one_hot, batch_size, repeat)

  def train_input_fn():
    # The shuffle buffer size is the number of training sentences, and the
    # order is redrawn on every iteration over the dataset.
    # NOTE(review): the shuffle is applied to whatever `to_dataset` returns;
    # if that helper already batches or repeats, this shuffles its output
    # elements rather than individual examples -- confirm this is intended.
    return _as_dataset(train_sentences, train_labels).shuffle(
        len(train_sentences), reshuffle_each_iteration=True)

  def eval_input_fn():
    # Evaluation data keeps the order produced by `to_dataset`.
    return _as_dataset(test_sentences, test_labels)

  return train_input_fn, eval_input_fn