Commit 0d1b00b1 authored by Vadim Markovtsev's avatar Vadim Markovtsev
Browse files

Swivel: move the rest of the ops to GPU

parent 89bccc63
...@@ -207,8 +207,6 @@ class SwivelModel(object): ...@@ -207,8 +207,6 @@ class SwivelModel(object):
sys.stdout.flush() sys.stdout.flush()
# ===== CREATE VARIABLES ====== # ===== CREATE VARIABLES ======
with tf.device('/cpu:0'):
# embeddings # embeddings
self.row_embedding = embeddings_with_init( self.row_embedding = embeddings_with_init(
embedding_dim=config.embedding_size, embedding_dim=config.embedding_size,
...@@ -224,25 +222,24 @@ class SwivelModel(object): ...@@ -224,25 +222,24 @@ class SwivelModel(object):
matrix_log_sum = math.log(np.sum(row_sums) + 1) matrix_log_sum = math.log(np.sum(row_sums) + 1)
row_bias_init = [math.log(x + 1) for x in row_sums] row_bias_init = [math.log(x + 1) for x in row_sums]
col_bias_init = [math.log(x + 1) for x in col_sums] col_bias_init = [math.log(x + 1) for x in col_sums]
self.row_bias = tf.Variable(row_bias_init, self.row_bias = tf.Variable(
trainable=config.trainable_bias) row_bias_init, trainable=config.trainable_bias)
self.col_bias = tf.Variable(col_bias_init, self.col_bias = tf.Variable(
trainable=config.trainable_bias) col_bias_init, trainable=config.trainable_bias)
tf.summary.histogram('row_bias', self.row_bias) tf.summary.histogram('row_bias', self.row_bias)
tf.summary.histogram('col_bias', self.col_bias) tf.summary.histogram('col_bias', self.col_bias)
# ===== CREATE GRAPH ===== # ===== CREATE GRAPH =====
# Get input # Get input
with tf.device('/cpu:0'):
global_row, global_col, count = count_matrix_input( global_row, global_col, count = count_matrix_input(
count_matrix_files, config.submatrix_rows, config.submatrix_cols) count_matrix_files, config.submatrix_rows, config.submatrix_cols)
# Fetch embeddings. # Fetch embeddings.
selected_row_embedding = tf.nn.embedding_lookup(self.row_embedding, selected_row_embedding = tf.nn.embedding_lookup(
global_row) self.row_embedding, global_row)
selected_col_embedding = tf.nn.embedding_lookup(self.col_embedding, selected_col_embedding = tf.nn.embedding_lookup(
global_col) self.col_embedding, global_col)
# Fetch biases. # Fetch biases.
selected_row_bias = tf.nn.embedding_lookup([self.row_bias], global_row) selected_row_bias = tf.nn.embedding_lookup([self.row_bias], global_row)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment