Commit 0d1b00b1 authored by Vadim Markovtsev's avatar Vadim Markovtsev
Browse files

Swivel: move the rest of the ops to GPU

parent 89bccc63
...@@ -207,46 +207,43 @@ class SwivelModel(object): ...@@ -207,46 +207,43 @@ class SwivelModel(object):
sys.stdout.flush() sys.stdout.flush()
# ===== CREATE VARIABLES ====== # ===== CREATE VARIABLES ======
# embeddings
with tf.device('/cpu:0'): self.row_embedding = embeddings_with_init(
# embeddings embedding_dim=config.embedding_size,
self.row_embedding = embeddings_with_init( vocab_size=self.n_rows,
embedding_dim=config.embedding_size, name='row_embedding')
vocab_size=self.n_rows, self.col_embedding = embeddings_with_init(
name='row_embedding') embedding_dim=config.embedding_size,
self.col_embedding = embeddings_with_init( vocab_size=self.n_cols,
embedding_dim=config.embedding_size, name='col_embedding')
vocab_size=self.n_cols, tf.summary.histogram('row_emb', self.row_embedding)
name='col_embedding') tf.summary.histogram('col_emb', self.col_embedding)
tf.summary.histogram('row_emb', self.row_embedding)
tf.summary.histogram('col_emb', self.col_embedding) matrix_log_sum = math.log(np.sum(row_sums) + 1)
row_bias_init = [math.log(x + 1) for x in row_sums]
matrix_log_sum = math.log(np.sum(row_sums) + 1) col_bias_init = [math.log(x + 1) for x in col_sums]
row_bias_init = [math.log(x + 1) for x in row_sums] self.row_bias = tf.Variable(
col_bias_init = [math.log(x + 1) for x in col_sums] row_bias_init, trainable=config.trainable_bias)
self.row_bias = tf.Variable(row_bias_init, self.col_bias = tf.Variable(
trainable=config.trainable_bias) col_bias_init, trainable=config.trainable_bias)
self.col_bias = tf.Variable(col_bias_init, tf.summary.histogram('row_bias', self.row_bias)
trainable=config.trainable_bias) tf.summary.histogram('col_bias', self.col_bias)
tf.summary.histogram('row_bias', self.row_bias)
tf.summary.histogram('col_bias', self.col_bias)
# ===== CREATE GRAPH ===== # ===== CREATE GRAPH =====
# Get input # Get input
with tf.device('/cpu:0'): global_row, global_col, count = count_matrix_input(
global_row, global_col, count = count_matrix_input( count_matrix_files, config.submatrix_rows, config.submatrix_cols)
count_matrix_files, config.submatrix_rows, config.submatrix_cols)
# Fetch embeddings.
# Fetch embeddings. selected_row_embedding = tf.nn.embedding_lookup(
selected_row_embedding = tf.nn.embedding_lookup(self.row_embedding, self.row_embedding, global_row)
global_row) selected_col_embedding = tf.nn.embedding_lookup(
selected_col_embedding = tf.nn.embedding_lookup(self.col_embedding, self.col_embedding, global_col)
global_col)
# Fetch biases.
# Fetch biases. selected_row_bias = tf.nn.embedding_lookup([self.row_bias], global_row)
selected_row_bias = tf.nn.embedding_lookup([self.row_bias], global_row) selected_col_bias = tf.nn.embedding_lookup([self.col_bias], global_col)
selected_col_bias = tf.nn.embedding_lookup([self.col_bias], global_col)
# Multiply the row and column embeddings to generate predictions. # Multiply the row and column embeddings to generate predictions.
predictions = tf.matmul( predictions = tf.matmul(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment