Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
0d1b00b1
Commit
0d1b00b1
authored
Mar 03, 2017
by
Vadim Markovtsev
Browse files
Swivel: move the rest of the ops to GPU
parent
89bccc63
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
33 additions
and
36 deletions
+33
-36
swivel/swivel.py
swivel/swivel.py
+33
-36
No files found.
swivel/swivel.py
View file @
0d1b00b1
...
...
@@ -207,8 +207,6 @@ class SwivelModel(object):
sys
.
stdout
.
flush
()
# ===== CREATE VARIABLES ======
with
tf
.
device
(
'/cpu:0'
):
# embeddings
self
.
row_embedding
=
embeddings_with_init
(
embedding_dim
=
config
.
embedding_size
,
...
...
@@ -224,25 +222,24 @@ class SwivelModel(object):
matrix_log_sum
=
math
.
log
(
np
.
sum
(
row_sums
)
+
1
)
row_bias_init
=
[
math
.
log
(
x
+
1
)
for
x
in
row_sums
]
col_bias_init
=
[
math
.
log
(
x
+
1
)
for
x
in
col_sums
]
self
.
row_bias
=
tf
.
Variable
(
row_bias_init
,
trainable
=
config
.
trainable_bias
)
self
.
col_bias
=
tf
.
Variable
(
col_bias_init
,
trainable
=
config
.
trainable_bias
)
self
.
row_bias
=
tf
.
Variable
(
row_bias_init
,
trainable
=
config
.
trainable_bias
)
self
.
col_bias
=
tf
.
Variable
(
col_bias_init
,
trainable
=
config
.
trainable_bias
)
tf
.
summary
.
histogram
(
'row_bias'
,
self
.
row_bias
)
tf
.
summary
.
histogram
(
'col_bias'
,
self
.
col_bias
)
# ===== CREATE GRAPH =====
# Get input
with
tf
.
device
(
'/cpu:0'
):
global_row
,
global_col
,
count
=
count_matrix_input
(
count_matrix_files
,
config
.
submatrix_rows
,
config
.
submatrix_cols
)
# Fetch embeddings.
selected_row_embedding
=
tf
.
nn
.
embedding_lookup
(
self
.
row_embedding
,
global_row
)
selected_col_embedding
=
tf
.
nn
.
embedding_lookup
(
self
.
col_embedding
,
global_col
)
selected_row_embedding
=
tf
.
nn
.
embedding_lookup
(
self
.
row_embedding
,
global_row
)
selected_col_embedding
=
tf
.
nn
.
embedding_lookup
(
self
.
col_embedding
,
global_col
)
# Fetch biases.
selected_row_bias
=
tf
.
nn
.
embedding_lookup
([
self
.
row_bias
],
global_row
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment