#!/usr/bin/env python
#
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Submatrix-wise Vector Embedding Learner.

Implementation of SwiVel algorithm described at:
http://arxiv.org/abs/1602.02215

This program expects an input directory that contains the following files.

  row_vocab.txt, col_vocab.txt

    The row and column vocabulary files.  Each file should contain one token
    per line; these will be used to generate a tab-separated file containing
    the trained embeddings.

  row_sums.txt, col_sums.txt

    The matrix row and column marginal sums.  Each file should contain one
    decimal floating point number per line, corresponding to the marginal
    count of the matrix for that row or column.

  shard-*.pb

    The submatrix shards.  Each shard is a file containing a single
    serialized tf.Example protocol buffer with the following features:

      global_row: the global row indices contained in the shard
      global_col: the global column indices contained in the shard
      sparse_local_row, sparse_local_col, sparse_value: three parallel arrays
      that are a sparse representation of the submatrix counts.

The program will generate embeddings, training on the input data for the
specified number of epochs.  When complete, it will write the trained vectors
to tab-separated files that contain one line per embedding.  Row and column
embeddings are stored in separate files.

"""

from __future__ import print_function
import glob
import math
import os
import threading
import time

import numpy as np
import tensorflow as tf
from tensorflow.python.client import device_lib

flags = tf.app.flags

flags.DEFINE_string('input_base_path', '/tmp/swivel_data',
                    'Directory containing input shards, vocabularies, '
                    'and marginals.')
flags.DEFINE_string('output_base_path', '/tmp/swivel_data',
                    'Path where to write the trained embeddings.')
flags.DEFINE_integer('embedding_size', 300, 'Size of the embeddings')
flags.DEFINE_boolean('trainable_bias', False, 'Biases are trainable')
flags.DEFINE_integer('submatrix_rows', 4096, 'Rows in each training submatrix. '
                     'This must match the training data.')
flags.DEFINE_integer('submatrix_cols', 4096,
                     'Columns in each training submatrix. '
                     'This must match the training data.')
flags.DEFINE_float('loss_multiplier', 1.0 / 4096,
                   'Constant multiplier on the loss.')
flags.DEFINE_float('confidence_exponent', 0.5,
                   'Exponent for l2 confidence function')
flags.DEFINE_float('confidence_scale', 0.25, 'Scale for l2 confidence function')
flags.DEFINE_float('confidence_base', 0.1, 'Base for l2 confidence function')
flags.DEFINE_float('learning_rate', 1.0, 'Initial learning rate')
flags.DEFINE_integer('num_concurrent_steps', 2,
                     'Number of threads to train with')
flags.DEFINE_integer('num_readers', 4,
                     'Number of threads to read the input data and feed it')
flags.DEFINE_float('num_epochs', 40, 'Number of epochs to train for')
flags.DEFINE_float('per_process_gpu_memory_fraction', 0,
                   'Fraction of GPU memory to use, 0 means allow_growth')
flags.DEFINE_integer('num_gpus', 0,
                     'Number of GPUs to use, 0 means all available')

FLAGS = flags.FLAGS


def log(message, *args, **kwargs):
  tf.logging.info(message, *args, **kwargs)


def get_available_gpus():
  return [d.name for d in device_lib.list_local_devices()
          if d.device_type == 'GPU']


def embeddings_with_init(vocab_size, embedding_dim, name):
  """Creates and initializes the embedding tensors."""
  return tf.get_variable(name=name,
                         shape=[vocab_size, embedding_dim],
                         initializer=tf.random_normal_initializer(
                             stddev=math.sqrt(1.0 / embedding_dim)))


def count_matrix_input(filenames, submatrix_rows, submatrix_cols):
  """Reads submatrix shards from disk."""
  filename_queue = tf.train.string_input_producer(filenames)
  reader = tf.WholeFileReader()
  _, serialized_example = reader.read(filename_queue)
  features = tf.parse_single_example(
      serialized_example,
      features={
          'global_row': tf.FixedLenFeature([submatrix_rows], dtype=tf.int64),
          'global_col': tf.FixedLenFeature([submatrix_cols], dtype=tf.int64),
          'sparse_local_row': tf.VarLenFeature(dtype=tf.int64),
          'sparse_local_col': tf.VarLenFeature(dtype=tf.int64),
          'sparse_value': tf.VarLenFeature(dtype=tf.float32)
      })

  global_row = features['global_row']
  global_col = features['global_col']

  sparse_local_row = features['sparse_local_row'].values
  sparse_local_col = features['sparse_local_col'].values
  sparse_count = features['sparse_value'].values

  sparse_indices = tf.concat([tf.expand_dims(sparse_local_row, 1),
                              tf.expand_dims(sparse_local_col, 1)], 1)
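  # Densify the sparse (row, col, value) triples into a full
  # submatrix_rows x submatrix_cols count matrix; pairs that never
  # co-occurred become explicit zeros.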
  count = tf.sparse_to_dense(sparse_indices, [submatrix_rows, submatrix_cols],
                             sparse_count)

  queued_global_row, queued_global_col, queued_count = tf.train.batch(
      [global_row, global_col, count],
      batch_size=1,
      num_threads=FLAGS.num_readers,
      capacity=32)

  queued_global_row = tf.reshape(queued_global_row, [submatrix_rows])
  queued_global_col = tf.reshape(queued_global_col, [submatrix_cols])
  queued_count = tf.reshape(queued_count, [submatrix_rows, submatrix_cols])

  return queued_global_row, queued_global_col, queued_count


def read_marginals_file(filename):
  """Reads text file with one number per line to an array."""
  with open(filename) as lines:
    return [float(line) for line in lines]


def write_embedding_tensor_to_disk(vocab_path, output_path, sess, embedding):
  """Writes an embedding tensor to output_path as a TSV file."""
  # Fetch the embedding values from the model.
  embeddings = sess.run(embedding)

  with open(output_path, 'w') as out_f:
    with open(vocab_path) as vocab_f:
      for index, word in enumerate(vocab_f):
        word = word.strip()
        vector = embeddings[index]
        out_f.write(word + '\t' + '\t'.join([str(x) for x in vector]) + '\n')


def write_embeddings_to_disk(config, model, sess):
  """Writes row and column embeddings disk"""
  # Row Embedding
  row_vocab_path = config.input_base_path + '/row_vocab.txt'
  row_embedding_output_path = config.output_base_path + '/row_embedding.tsv'
  log('Writing row embeddings to: %s', row_embedding_output_path)
  write_embedding_tensor_to_disk(row_vocab_path, row_embedding_output_path,
                                 sess, model.row_embedding)

  # Column Embedding
  col_vocab_path = config.input_base_path + '/col_vocab.txt'
  col_embedding_output_path = config.output_base_path + '/col_embedding.tsv'
  log('Writing column embeddings to: %s', col_embedding_output_path)
  write_embedding_tensor_to_disk(col_vocab_path, col_embedding_output_path,
                                 sess, model.col_embedding)


class SwivelModel(object):
  """Small class to gather needed pieces from a Graph being built."""

  def __init__(self, config):
    """Construct graph for dmc."""
    self._config = config

    # Create paths to input data files
    log('Reading model from: %s', config.input_base_path)
    count_matrix_files = glob.glob(config.input_base_path + '/shard-*.pb')
    row_sums_path = config.input_base_path + '/row_sums.txt'
    col_sums_path = config.input_base_path + '/col_sums.txt'

    # Read marginals
    row_sums = read_marginals_file(row_sums_path)
    col_sums = read_marginals_file(col_sums_path)

    self.n_rows = len(row_sums)
    self.n_cols = len(col_sums)
    log('Matrix dim: (%d,%d) SubMatrix dim: (%d,%d)',
        self.n_rows, self.n_cols, config.submatrix_rows, config.submatrix_cols)
    self.n_submatrices = (self.n_rows * self.n_cols //
                          (config.submatrix_rows * config.submatrix_cols))
    log('n_submatrices: %d', self.n_submatrices)

    with tf.device('/cpu:0'):
      # ===== CREATE VARIABLES ======
      # Get input
      global_row, global_col, count = count_matrix_input(
        count_matrix_files, config.submatrix_rows, config.submatrix_cols)

      # Embeddings
      self.row_embedding = embeddings_with_init(
        embedding_dim=config.embedding_size,
        vocab_size=self.n_rows,
        name='row_embedding')
      self.col_embedding = embeddings_with_init(
        embedding_dim=config.embedding_size,
        vocab_size=self.n_cols,
        name='col_embedding')
      tf.summary.histogram('row_emb', self.row_embedding)
      tf.summary.histogram('col_emb', self.col_embedding)

      matrix_log_sum = math.log(np.sum(row_sums) + 1)
      row_bias_init = [math.log(x + 1) for x in row_sums]
      col_bias_init = [math.log(x + 1) for x in col_sums]
      self.row_bias = tf.Variable(
          row_bias_init, trainable=config.trainable_bias)
      self.col_bias = tf.Variable(
          col_bias_init, trainable=config.trainable_bias)
      tf.summary.histogram('row_bias', self.row_bias)
      tf.summary.histogram('col_bias', self.col_bias)

      # Add optimizer
      l2_losses = []
      sigmoid_losses = []
      self.global_step = tf.Variable(0, name='global_step')
      opt = tf.train.AdagradOptimizer(config.learning_rate)

      all_grads = []

    devices = ['/gpu:%d' % i for i in range(FLAGS.num_gpus)] \
        if FLAGS.num_gpus > 0 else get_available_gpus()
    if not devices:
      # No GPUs are available: build the tower on the CPU instead.
      devices = ['/cpu:0']
    self.devices_number = len(devices)
    with tf.variable_scope(tf.get_variable_scope()):
      for dev in devices:
        with tf.device(dev):
          with tf.name_scope(dev[1:].replace(':', '_')):
            # ===== CREATE GRAPH =====
            # Fetch embeddings.
            selected_row_embedding = tf.nn.embedding_lookup(
                self.row_embedding, global_row)
            selected_col_embedding = tf.nn.embedding_lookup(
                self.col_embedding, global_col)

            # Fetch biases.
            selected_row_bias = tf.nn.embedding_lookup(
                [self.row_bias], global_row)
            selected_col_bias = tf.nn.embedding_lookup(
                [self.col_bias], global_col)

            # Multiply the row and column embeddings to generate predictions.
            predictions = tf.matmul(
                selected_row_embedding, selected_col_embedding,
                transpose_b=True)

            # These binary masks separate zero from non-zero values.
            count_is_nonzero = tf.to_float(tf.cast(count, tf.bool))
            count_is_zero = 1 - count_is_nonzero

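            # For a nonzero count, this objective is log(count) - row_bias
            # - col_bias + log(total); with the default log-marginal biases
            # that is the pointwise mutual information (PMI) of the pair,
            # which the embedding dot products are trained to approximate.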
            objectives = count_is_nonzero * tf.log(count + 1e-30)
            objectives -= tf.reshape(
                selected_row_bias, [config.submatrix_rows, 1])
            objectives -= selected_col_bias
            objectives += matrix_log_sum

            err = predictions - objectives

            # The confidence function scales the L2 loss based on the raw
            # co-occurrence count.
            l2_confidence = (config.confidence_base +
                             config.confidence_scale * tf.pow(
                                 count, config.confidence_exponent))
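            # i.e. confidence = base + scale * count**exponent: frequent
            # co-occurrences contribute proportionally more to the L2 loss.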

            l2_loss = config.loss_multiplier * tf.reduce_sum(
                0.5 * l2_confidence * err * err * count_is_nonzero)
            l2_losses.append(tf.expand_dims(l2_loss, 0))

            sigmoid_loss = config.loss_multiplier * tf.reduce_sum(
                tf.nn.softplus(err) * count_is_zero)
            sigmoid_losses.append(tf.expand_dims(sigmoid_loss, 0))

            loss = l2_loss + sigmoid_loss
            grads = opt.compute_gradients(loss)
            all_grads.append(grads)

    with tf.device('/cpu:0'):
      # ===== MERGE LOSSES =====
      l2_loss = tf.reduce_mean(tf.concat(l2_losses, 0), 0, name="l2_loss")
      sigmoid_loss = tf.reduce_mean(tf.concat(sigmoid_losses, 0), 0,
                                    name="sigmoid_loss")
      self.loss = l2_loss + sigmoid_loss
      average = tf.train.ExponentialMovingAverage(0.8, self.global_step)
      loss_average_op = average.apply((self.loss,))
      tf.summary.scalar("l2_loss", l2_loss)
      tf.summary.scalar("sigmoid_loss", sigmoid_loss)
      tf.summary.scalar("loss", self.loss)

      # Apply the gradients to adjust the shared variables.
      apply_gradient_ops = []
      for grads in all_grads:
        apply_gradient_ops.append(opt.apply_gradients(
            grads, global_step=self.global_step))

      self.train_op = tf.group(loss_average_op, *apply_gradient_ops)
      self.saver = tf.train.Saver(sharded=True)


def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  start_time = time.time()

  # Create the output path.  If this fails, it really ought to fail
  # now. :)
  if not os.path.isdir(FLAGS.output_base_path):
    os.makedirs(FLAGS.output_base_path)

  # Create and run model
  with tf.Graph().as_default():
    model = SwivelModel(FLAGS)

    # Create a session for running Ops on the Graph.
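    # Either pin a fixed fraction of per-process GPU memory or let
    # TensorFlow grow allocations on demand; the flag default of 0 selects
    # allow_growth.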
    gpu_opts = {}
    if FLAGS.per_process_gpu_memory_fraction > 0:
      gpu_opts['per_process_gpu_memory_fraction'] = (
          FLAGS.per_process_gpu_memory_fraction)
    else:
      gpu_opts['allow_growth'] = True
    gpu_options = tf.GPUOptions(**gpu_opts)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

    # Run the Op to initialize the variables.
    sess.run(tf.global_variables_initializer())

    # Start feeding input
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Calculate how many steps each thread should run
    n_total_steps = int(FLAGS.num_epochs * model.n_rows * model.n_cols) // (
        FLAGS.submatrix_rows * FLAGS.submatrix_cols)
    n_steps_per_thread = n_total_steps // (
        FLAGS.num_concurrent_steps * model.devices_number)
    n_submatrices_to_train = model.n_submatrices * FLAGS.num_epochs
    t0 = [time.time()]
    n_steps_between_status_updates = 100
    status_i = [0]
    status_lock = threading.Lock()
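    # The outer %-format only fixes the field width to the digit count of
    # n_submatrices_to_train; the doubled %% escapes survive it and become
    # the real conversion specifiers used when logging.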
    msg = ('%%%dd/%%d submatrices trained (%%.1f%%%%), %%5.1f submatrices/sec |'
           ' loss %%f') % len(str(n_submatrices_to_train))

    def TrainingFn():
      for _ in range(n_steps_per_thread):
        _, global_step, loss = sess.run((
            model.train_op, model.global_step, model.loss))

        show_status = False
        with status_lock:
          new_i = global_step // n_steps_between_status_updates
          if new_i > status_i[0]:
            status_i[0] = new_i
            show_status = True
        if show_status:
          elapsed = float(time.time() - t0[0])
          log(msg, global_step, n_submatrices_to_train,
              100.0 * global_step / n_submatrices_to_train,
              n_steps_between_status_updates / elapsed, loss)
          t0[0] = time.time()

    # Start training threads
    train_threads = []
    for _ in range(FLAGS.num_concurrent_steps):
      t = threading.Thread(target=TrainingFn)
      train_threads.append(t)
      t.start()

    # Wait for threads to finish.
    for t in train_threads:
      t.join()

    coord.request_stop()
    coord.join(threads)

    # Write out vectors
    write_embeddings_to_disk(FLAGS, model, sess)

    # Shutdown
    sess.close()
    log("Elapsed: %s", time.time() - start_time)


if __name__ == '__main__':
  tf.app.run()