"""Main function for the sentiment analysis model.
2

3
4
5
The model makes use of concatenation of two CNN layers with
different kernel sizes. See `sentiment_model.py`
for more details about the models.
6
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os

import tensorflow as tf

from data import dataset
import sentiment_model


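# Dropout rate passed to the model (see `sentiment_model.py`).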
_DROPOUT_RATE = 0.95


def run_model(dataset_name, emb_dim, voc_size, sen_len,
              hid_dim, batch_size, epochs, model_save_dir):
  """Run training loop and an evaluation at the end.

  Args:
    dataset_name: Name of the dataset to train and evaluate on.
    emb_dim: The dimension of the Embedding layer.
    voc_size: The number of the most frequent tokens
      to be used from the corpus.
    sen_len: The number of words in each sentence.
      Longer sentences get cut, shorter ones padded.
    hid_dim: The number of filters in each CNN layer.
    batch_size: The size of each batch during training.
    epochs: The number of passes over the training set.
    model_save_dir: Directory in which checkpoints and the final
      model are saved.
  """
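  # Build the CNN classifier; the number of classes comes from the dataset.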
  model = sentiment_model.CNN(emb_dim, voc_size, sen_len,
                              hid_dim, dataset.get_num_class(dataset_name),
                              _DROPOUT_RATE)
  model.summary()

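  # Categorical cross-entropy expects one-hot encoded labels.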
  model.compile(loss="categorical_crossentropy",
                optimizer="rmsprop",
                metrics=["accuracy"])

  tf.logging.info("Loading the data")
  x_train, y_train, x_test, y_test = dataset.load(
      dataset_name, voc_size, sen_len)

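  # Create the save directory and checkpoint weights whenever
  # validation accuracy improves.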
  if not os.path.exists(model_save_dir):
    os.makedirs(model_save_dir)

  filepath = os.path.join(model_save_dir, "model-{epoch:02d}.hdf5")
  checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
      filepath, monitor="val_accuracy", verbose=1,
      save_best_only=True, save_weights_only=True, mode="auto")

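  # Hold out 40% of the training data for validation.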
  model.fit(x_train, y_train, batch_size=batch_size,
            validation_split=0.4, epochs=epochs,
            callbacks=[checkpoint_callback])

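  # Evaluate the last-epoch weights on the held-out test set.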
  score = model.evaluate(x_test, y_test, batch_size=batch_size)
  
  # Save the complete trained model (architecture and weights) to HDF5.
  model.save(os.path.join(model_save_dir, "full-model.h5"))

  tf.logging.info("Score: {}".format(score))

if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument("-d", "--dataset", help="Dataset to be trained "
                                              "and evaluated.",
                      type=str, choices=["imdb"], default="imdb")

  parser.add_argument("-e", "--embedding_dim",
                      help="The dimension of the Embedding layer.",
                      type=int, default=512)

  parser.add_argument("-v", "--vocabulary_size",
                      help="The number of the words to be considered "
                           "in the dataset corpus.",
                      type=int, default=6000)

  parser.add_argument("-s", "--sentence_length",
                      help="The number of words in a data point."
                           "Entries of smaller length are padded.",
                      type=int, default=600)

  parser.add_argument("-c", "--hidden_dim",
                      help="The number of the CNN layer filters.",
                      type=int, default=512)

  parser.add_argument("-b", "--batch_size",
                      help="The size of each batch for training.",
                      type=int, default=500)

  parser.add_argument("-p", "--epochs",
                      help="The number of epochs for training.",
                      type=int, default=55)

  parser.add_argument("-f", "--folder",
                      help="folder/dir to save trained model",
                      type=str, default=None)
  args = parser.parse_args()

  if args.folder is None:
    parser.error("the -f/--folder argument is required: provide a "
                 "directory in which to save the model.")

  run_model(args.dataset, args.embedding_dim, args.vocabulary_size,
            args.sentence_length, args.hidden_dim,
            args.batch_size, args.epochs, args.folder)
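# Example invocation (the save directory is required; other flags have defaults):
#   python sentiment_main.py --dataset imdb --folder ./checkpoints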