dataset.py 1.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
"""Dataset module for sentiment analysis.

Currently only the imdb dataset is available.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

10
11
12
13
14
import data.imdb as imdb

DATASET_IMDB = "imdb"


15
16
def load(dataset, vocabulary_size, sentence_length):
  """Returns training and evaluation input.

  Dispatches to the loader of the requested dataset; for imdb this is
  `imdb.load(vocabulary_size, sentence_length)`.

  Args:
    dataset: Dataset to be trained and evaluated.
      Currently only imdb (`DATASET_IMDB`) is supported.
    vocabulary_size: The number of the most frequent tokens
      to be used from the corpus.
    sentence_length: The number of words in each sentence.
      Longer sentences get cut, shorter ones padded.
  Returns:
    A tuple of length 4, for training sentences, labels,
    evaluation sentences, and evaluation labels,
    each being a numpy array.
  Raises:
    ValueError: if the dataset value is not valid.
  """
  if dataset == DATASET_IMDB:
    return imdb.load(vocabulary_size, sentence_length)
  # Format instead of concatenating: with `+`, a non-string dataset
  # (e.g. None) would raise TypeError rather than the documented ValueError.
  raise ValueError("unsupported dataset: {}".format(dataset))


def get_num_class(dataset):
  """Returns an integer for the number of label classes.

  For imdb the value is read from `imdb.NUM_CLASS`.

  Args:
    dataset: Dataset to be trained and evaluated.
      Currently only imdb (`DATASET_IMDB`) is supported.
  Returns:
    int: The number of label classes.
  Raises:
    ValueError: if the dataset value is not valid.
  """
  if dataset == DATASET_IMDB:
    return imdb.NUM_CLASS
  # Format instead of concatenating: with `+`, a non-string dataset
  # (e.g. None) would raise TypeError rather than the documented ValueError.
  raise ValueError("unsupported dataset: {}".format(dataset))