#!/usr/bin/env python
# Copyright 2017, 2018 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Trains the LexNET path-based model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import lexnet_common
import path_model
from sklearn import metrics
import tensorflow as tf

tf.flags.DEFINE_string('train', '', 'Training set (GZIP-compressed TFRecords).')
tf.flags.DEFINE_string('val', '', 'Validation set (GZIP-compressed TFRecords).')
tf.flags.DEFINE_string('test', '', 'Test set (GZIP-compressed TFRecords).')
tf.flags.DEFINE_string('embeddings', '', 'Word embeddings file (.npy).')
tf.flags.DEFINE_string('relations', '', 'Relation labels, one per line.')
tf.flags.DEFINE_string('output_dir', '', 'Output directory for path embeddings.')
tf.flags.DEFINE_string('logdir', '', 'Directory for checkpoints and logs.')
FLAGS = tf.flags.FLAGS
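
# Example invocation (paths are illustrative; substitute your own files):
#
#   python learn_path_embeddings.py \
#     --train=/path/to/train.tfrecs.gz \
#     --val=/path/to/val.tfrecs.gz \
#     --test=/path/to/test.tfrecs.gz \
#     --embeddings=/path/to/lemma_embeddings.npy \
#     --relations=/path/to/relations.txt \
#     --output_dir=/path/to/path_embeddings \
#     --logdir=/path/to/logdir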


def main(_):
  # Pick up any one-off hyper-parameters.
  hparams = path_model.PathBasedModel.default_hparams()

  with open(FLAGS.relations) as fh:
    relations = fh.read().splitlines()

  hparams.num_classes = len(relations)
  print('Model will predict into %d classes' % hparams.num_classes)

  print('Running with hyper-parameters: {}'.format(hparams))

  # Load the instances
  print('Loading instances...')
  opts = tf.python_io.TFRecordOptions(
      compression_type=tf.python_io.TFRecordCompressionType.GZIP)

  train_instances = list(tf.python_io.tf_record_iterator(FLAGS.train, opts))
  val_instances = list(tf.python_io.tf_record_iterator(FLAGS.val, opts))
  test_instances = list(tf.python_io.tf_record_iterator(FLAGS.test, opts))

  # Load the word embeddings
  print('Loading word embeddings...')
  lemma_embeddings = lexnet_common.load_word_embeddings(FLAGS.embeddings)

  # Define the graph and the model
  with tf.Graph().as_default():
    with tf.variable_scope('lexnet'):
      options = tf.python_io.TFRecordOptions(
          compression_type=tf.python_io.TFRecordCompressionType.GZIP)
      reader = tf.TFRecordReader(options=options)
      _, train_instance = reader.read(
          tf.train.string_input_producer([FLAGS.train]))
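
      # Feed the serialized training records through a shuffling queue so they
      # are dequeued one at a time (batch_size=1) in randomized order.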
      shuffled_train_instance = tf.train.shuffle_batch(
          [train_instance],
          batch_size=1,
          num_threads=1,
          capacity=len(train_instances),
          min_after_dequeue=100,
      )[0]

      train_model = path_model.PathBasedModel(
          hparams, lemma_embeddings, shuffled_train_instance)

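    # Build a second copy of the model that shares variables with the training
    # model (reuse=True) but reads instances from a feed placeholder, so it
    # can be run on the validation set between epochs.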
    with tf.variable_scope('lexnet', reuse=True):
      val_instance = tf.placeholder(dtype=tf.string)
      val_model = path_model.PathBasedModel(
          hparams, lemma_embeddings, val_instance)

    # Initialize a session and start training
    best_model_saver = tf.train.Saver()
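    # Track the best validation F1 in a non-trainable variable so the value is
    # stored in checkpoints alongside the model weights.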
    f1_t = tf.placeholder(tf.float32)
    best_f1_t = tf.Variable(0.0, trainable=False, name='best_f1')
    assign_best_f1_op = tf.assign(best_f1_t, f1_t)

    supervisor = tf.train.Supervisor(
        logdir=FLAGS.logdir,
        global_step=train_model.global_step)

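    # managed_session() initializes variables, starts the input queue runners,
    # and periodically checkpoints the graph to --logdir.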
    with supervisor.managed_session() as session:
      # Load the labels
      print('Loading labels...')
      val_labels = train_model.load_labels(session, val_instances)

      # Train the model
      print('Training the model...')

      while True:
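        # The global step advances once per training instance (batch size 1),
        # so the completed-epoch count is roughly ceil(step / #instances).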
        step = session.run(train_model.global_step)
        epoch = (step + len(train_instances) - 1) // len(train_instances)
        if epoch > hparams.num_epochs:
          break

        print('Starting epoch %d (step %d)...' % (1 + epoch, step))

        epoch_loss = train_model.run_one_epoch(session, len(train_instances))

        best_f1 = session.run(best_f1_t)
        f1 = epoch_completed(val_model, session, epoch, epoch_loss,
                             val_instances, val_labels, best_model_saver,
                             FLAGS.logdir, best_f1)

        if f1 > best_f1:
          session.run(assign_best_f1_op, {f1_t: f1})

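        # Early stopping: abandon training if validation F1 drops more than
        # 0.08 below the best score seen so far.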
        if f1 < best_f1 - 0.08:
          tf.logging.info('Stopping training after %d epochs.\n' % epoch)
          break

      # Print the best performance on the validation set
      best_f1 = session.run(best_f1_t)
      print('Best performance on the validation set: F1=%.3f' % best_f1)

      # Save the path embeddings
      print('Computing the path embeddings...')
      instances = train_instances + val_instances + test_instances
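      # The placeholder-fed validation model is reused here, so embeddings can
      # be computed for every instance across all three splits.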
      path_index, path_vectors = path_model.compute_path_embeddings(
          val_model, session, instances)

      # NOTE: FLAGS.output_dir is assumed to be the path-embedding directory.
      path_emb_dir = FLAGS.output_dir
      if not os.path.exists(path_emb_dir):
        os.makedirs(path_emb_dir)

      path_model.save_path_embeddings(
          val_model, path_vectors, path_index, FLAGS.output_dir)


def epoch_completed(model, session, epoch, epoch_loss,
                    val_instances, val_labels, saver, save_path, best_f1):
  """Runs every time an epoch completes.

  Print the performance on the validation set, and update the saved model if
  its performance is better on the previous ones. If the performance dropped,
  tell the training to stop.

  Args:
    model: The currently trained path-based model.
    session: The current TensorFlow session.
    epoch: The epoch number.
    epoch_loss: The current epoch loss.
    val_instances: The validation set instances (evaluation between epochs).
    val_labels: The validation set labels (for evaluation between epochs).
    saver: tf.Saver object
    save_path: Where to save the model.
    best_f1: the best F1 achieved so far.

  Returns:
    The F1 achieved on the training set.
  """
  # Evaluate on the validation set
  val_pred = model.predict(session, val_instances)
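  # Weighted averaging weights each class's scores by its support, which
  # accounts for class imbalance among the relation labels.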
  precision, recall, f1, _ = metrics.precision_recall_fscore_support(
      val_labels, val_pred, average='weighted')
  print(
      'Epoch: %d/%d, Loss: %f, validation set: P: %.3f, R: %.3f, F1: %.3f\n' % (
          epoch + 1, model.hparams.num_epochs, epoch_loss,
          precision, recall, f1))

  if f1 > best_f1:
    save_filename = os.path.join(save_path, 'best.ckpt')
    print('Saving model in: %s' % save_filename)
    saver.save(session, save_filename)
    print('Model saved in file: %s' % save_filename)

  return f1


if __name__ == '__main__':
  tf.app.run(main)