dataset.py 1.38 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import data.imdb as imdb

# Dataset identifier accepted by construct_input_fns() and get_num_class();
# any other value makes those functions raise ValueError.
DATASET_IMDB = "imdb"


def construct_input_fns(dataset, batch_size, vocabulary_size,
                        sentence_length, repeat=1):
  """Returns training and evaluation input functions.

  Args:
    dataset: Dataset to be trained and evaluated.
      Currently only imdb is supported.
    batch_size: Number of examples in each batch.
    vocabulary_size: The number of the most frequent tokens
      to be used from the corpus.
    sentence_length: The number of words in each sentence.
      Longer sentences get cut, shorter ones padded.
    repeat: The number of epochs; forwarded unchanged to the
      dataset-specific constructor.

  Returns:
    A tuple (train_input_fn, eval_input_fn) of training and evaluation
    input functions.

  Raises:
    ValueError: if the dataset value is not valid.
  """
  if dataset != DATASET_IMDB:
    # str.format (rather than "+") keeps this a ValueError even when
    # `dataset` is not a string (e.g. None), instead of a TypeError.
    raise ValueError("unsupported dataset: {}".format(dataset))
  # NOTE(review): arguments are passed positionally in a different order
  # than this function's own signature; this assumes
  # imdb.construct_input_fns takes (vocabulary_size, sentence_length,
  # batch_size, ...) — confirm against data/imdb.py.
  train_input_fn, eval_input_fn = imdb.construct_input_fns(
      vocabulary_size, sentence_length, batch_size, repeat=repeat)
  return train_input_fn, eval_input_fn


def get_num_class(dataset):
  """Returns an integer for the number of label classes.

  Args:
    dataset: Dataset to be trained and evaluated.
      Currently only imdb is supported.

  Returns:
    The number of label classes for `dataset` (imdb.NUM_CLASS for imdb).

  Raises:
    ValueError: if the dataset value is not valid.
  """
  if dataset != DATASET_IMDB:
    # str.format (rather than "+") keeps this a ValueError even when
    # `dataset` is not a string (e.g. None), instead of a TypeError.
    raise ValueError("unsupported dataset: {}".format(dataset))
  return imdb.NUM_CLASS