Commit 0d1b00b1 authored by Vadim Markovtsev's avatar Vadim Markovtsev
Browse files

Swivel: move the rest of the ops to GPU

parent 89bccc63
......@@ -207,8 +207,6 @@ class SwivelModel(object):
sys.stdout.flush()
# ===== CREATE VARIABLES ======
with tf.device('/cpu:0'):
# embeddings
self.row_embedding = embeddings_with_init(
embedding_dim=config.embedding_size,
......@@ -224,25 +222,24 @@ class SwivelModel(object):
matrix_log_sum = math.log(np.sum(row_sums) + 1)
row_bias_init = [math.log(x + 1) for x in row_sums]
col_bias_init = [math.log(x + 1) for x in col_sums]
self.row_bias = tf.Variable(row_bias_init,
trainable=config.trainable_bias)
self.col_bias = tf.Variable(col_bias_init,
trainable=config.trainable_bias)
self.row_bias = tf.Variable(
row_bias_init, trainable=config.trainable_bias)
self.col_bias = tf.Variable(
col_bias_init, trainable=config.trainable_bias)
tf.summary.histogram('row_bias', self.row_bias)
tf.summary.histogram('col_bias', self.col_bias)
# ===== CREATE GRAPH =====
# Get input
with tf.device('/cpu:0'):
global_row, global_col, count = count_matrix_input(
count_matrix_files, config.submatrix_rows, config.submatrix_cols)
# Fetch embeddings.
selected_row_embedding = tf.nn.embedding_lookup(self.row_embedding,
global_row)
selected_col_embedding = tf.nn.embedding_lookup(self.col_embedding,
global_col)
selected_row_embedding = tf.nn.embedding_lookup(
self.row_embedding, global_row)
selected_col_embedding = tf.nn.embedding_lookup(
self.col_embedding, global_col)
# Fetch biases.
selected_row_bias = tf.nn.embedding_lookup([self.row_bias], global_row)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment