Unverified Commit ae3b0a67 authored by Naveenkhasyap's avatar Naveenkhasyap Committed by GitHub
Browse files

#7532 Adding model save support for Sentiment Analysis model (#7715)

* #7532 Adding model save option for Sentiment model

* Updated the code as per review comments.
parent b3e4fefd
...@@ -10,17 +10,20 @@ from __future__ import division ...@@ -10,17 +10,20 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse import argparse
import os
import tensorflow as tf import tensorflow as tf
from data import dataset from data import dataset
import sentiment_model import sentiment_model
_DROPOUT_RATE = 0.95 _DROPOUT_RATE = 0.95
def run_model(dataset_name, emb_dim, voc_size, sen_len, def run_model(dataset_name, emb_dim, voc_size, sen_len,
hid_dim, batch_size, epochs): hid_dim, batch_size, epochs, model_save_dir):
"""Run training loop and an evaluation at the end. """Run training loop and an evaluation at the end.
Args: Args:
...@@ -48,9 +51,23 @@ def run_model(dataset_name, emb_dim, voc_size, sen_len, ...@@ -48,9 +51,23 @@ def run_model(dataset_name, emb_dim, voc_size, sen_len,
x_train, y_train, x_test, y_test = dataset.load( x_train, y_train, x_test, y_test = dataset.load(
dataset_name, voc_size, sen_len) dataset_name, voc_size, sen_len)
if not os.path.exists(model_save_dir):
os.makedirs(model_save_dir)
filepath=model_save_dir+"/model-{epoch:02d}.hdf5"
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(filepath, monitor='val_accuracy',
verbose=1,save_best_only=True,
save_weights_only=True,mode='auto')
model.fit(x_train, y_train, batch_size=batch_size, model.fit(x_train, y_train, batch_size=batch_size,
validation_split=0.4, epochs=epochs) validation_split=0.4, epochs=epochs, callbacks=[checkpoint_callback])
score = model.evaluate(x_test, y_test, batch_size=batch_size) score = model.evaluate(x_test, y_test, batch_size=batch_size)
model.save(os.path.join(model_save_dir, "full-model.h5"))
tf.logging.info("Score: {}".format(score)) tf.logging.info("Score: {}".format(score))
if __name__ == "__main__": if __name__ == "__main__":
...@@ -85,8 +102,14 @@ if __name__ == "__main__": ...@@ -85,8 +102,14 @@ if __name__ == "__main__":
help="The number of epochs for training.", help="The number of epochs for training.",
type=int, default=55) type=int, default=55)
parser.add_argument("-f", "--folder",
help="folder/dir to save trained model",
type=str, default=None)
args = parser.parse_args() args = parser.parse_args()
if args.folder is None:
parser.error("-f argument folder/dir to save is None,provide path to save model.")
run_model(args.dataset, args.embedding_dim, args.vocabulary_size, run_model(args.dataset, args.embedding_dim, args.vocabulary_size,
args.sentence_length, args.hidden_dim, args.sentence_length, args.hidden_dim,
args.batch_size, args.epochs) args.batch_size, args.epochs, args.folder)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment