Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
b9665e9b
Commit
b9665e9b
authored
Sep 20, 2018
by
Menglong Zhu
Committed by
dreamdragon
Oct 24, 2018
Browse files
Allowing the option to produce LSTM outputs with the bottleneck feature map concatenated.
PiperOrigin-RevId: 213873057
parent
d0c1b9da
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
38 additions
and
8 deletions
+38
-8
research/lstm_object_detection/lstm/lstm_cells.py
research/lstm_object_detection/lstm/lstm_cells.py
+14
-7
research/lstm_object_detection/lstm/lstm_cells_test.py
research/lstm_object_detection/lstm/lstm_cells_test.py
+24
-1
No files found.
research/lstm_object_detection/lstm/lstm_cells.py
View file @
b9665e9b
...
@@ -45,19 +45,22 @@ class BottleneckConvLSTMCell(tf.contrib.rnn.RNNCell):
...
@@ -45,19 +45,22 @@ class BottleneckConvLSTMCell(tf.contrib.rnn.RNNCell):
forget_bias
=
1.0
,
forget_bias
=
1.0
,
activation
=
tf
.
tanh
,
activation
=
tf
.
tanh
,
flattened_state
=
False
,
flattened_state
=
False
,
output_bottleneck
=
False
,
visualize_gates
=
True
):
visualize_gates
=
True
):
"""Initializes the basic LSTM cell.
"""Initializes the basic LSTM cell.
Args:
Args:
filter_size: collection, conv filter size
filter_size: collection, conv filter size
.
output_size: collection, the width/height dimensions of the cell/output
output_size: collection, the width/height dimensions of the cell/output
.
num_units: int, The number of channels in the LSTM cell.
num_units: int, The number of channels in the LSTM cell.
forget_bias: float, The bias added to forget gates (see above).
forget_bias: float, The bias added to forget gates (see above).
activation: Activation function of the inner states.
activation: Activation function of the inner states.
flattened_state: if True, state tensor will be flattened and stored as
flattened_state: if True, state tensor will be flattened and stored as
a 2-d tensor. Use for exporting the model to tfmini
a 2-d tensor. Use for exporting the model to tfmini.
output_bottleneck: if True, the cell bottleneck will be concatenated
to the cell output.
visualize_gates: if True, add histogram summaries of all gates
visualize_gates: if True, add histogram summaries of all gates
and outputs to tensorboard
and outputs to tensorboard
.
"""
"""
self
.
_filter_size
=
list
(
filter_size
)
self
.
_filter_size
=
list
(
filter_size
)
self
.
_output_size
=
list
(
output_size
)
self
.
_output_size
=
list
(
output_size
)
...
@@ -66,6 +69,7 @@ class BottleneckConvLSTMCell(tf.contrib.rnn.RNNCell):
...
@@ -66,6 +69,7 @@ class BottleneckConvLSTMCell(tf.contrib.rnn.RNNCell):
self
.
_activation
=
activation
self
.
_activation
=
activation
self
.
_viz_gates
=
visualize_gates
self
.
_viz_gates
=
visualize_gates
self
.
_flattened_state
=
flattened_state
self
.
_flattened_state
=
flattened_state
self
.
_output_bottleneck
=
output_bottleneck
self
.
_param_count
=
self
.
_num_units
self
.
_param_count
=
self
.
_num_units
for
dim
in
self
.
_output_size
:
for
dim
in
self
.
_output_size
:
self
.
_param_count
*=
dim
self
.
_param_count
*=
dim
...
@@ -99,7 +103,7 @@ class BottleneckConvLSTMCell(tf.contrib.rnn.RNNCell):
...
@@ -99,7 +103,7 @@ class BottleneckConvLSTMCell(tf.contrib.rnn.RNNCell):
with
tf
.
variable_scope
(
scope
):
with
tf
.
variable_scope
(
scope
):
c
,
h
=
state
c
,
h
=
state
# unflatten state if nec
c
esary
# unflatten state if nece
s
sary
if
self
.
_flattened_state
:
if
self
.
_flattened_state
:
c
=
tf
.
reshape
(
c
,
[
-
1
]
+
self
.
output_size
)
c
=
tf
.
reshape
(
c
,
[
-
1
]
+
self
.
output_size
)
h
=
tf
.
reshape
(
h
,
[
-
1
]
+
self
.
output_size
)
h
=
tf
.
reshape
(
h
,
[
-
1
]
+
self
.
output_size
)
...
@@ -140,13 +144,16 @@ class BottleneckConvLSTMCell(tf.contrib.rnn.RNNCell):
...
@@ -140,13 +144,16 @@ class BottleneckConvLSTMCell(tf.contrib.rnn.RNNCell):
slim
.
summaries
.
add_histogram_summary
(
new_h
,
'cell_output'
)
slim
.
summaries
.
add_histogram_summary
(
new_h
,
'cell_output'
)
slim
.
summaries
.
add_histogram_summary
(
new_c
,
'cell_state'
)
slim
.
summaries
.
add_histogram_summary
(
new_c
,
'cell_state'
)
output
=
new_h
if
self
.
_output_bottleneck
:
output
=
tf
.
concat
([
new_h
,
bottleneck
],
axis
=
3
)
# reflatten state to store it
# reflatten state to store it
if
self
.
_flattened_state
:
if
self
.
_flattened_state
:
new_c
=
tf
.
reshape
(
new_c
,
[
-
1
,
self
.
_param_count
])
new_c
=
tf
.
reshape
(
new_c
,
[
-
1
,
self
.
_param_count
])
new_h
=
tf
.
reshape
(
new_h
,
[
-
1
,
self
.
_param_count
])
new_h
=
tf
.
reshape
(
new_h
,
[
-
1
,
self
.
_param_count
])
return
new_h
,
tf
.
contrib
.
rnn
.
LSTMStateTuple
(
return
output
,
tf
.
contrib
.
rnn
.
LSTMStateTuple
(
new_c
,
new_h
)
new_c
,
new_h
if
self
.
_flattened_state
else
new_h
)
def
init_state
(
self
,
state_name
,
batch_size
,
dtype
,
learned_state
=
False
):
def
init_state
(
self
,
state_name
,
batch_size
,
dtype
,
learned_state
=
False
):
"""Creates an initial state compatible with this cell.
"""Creates an initial state compatible with this cell.
...
...
research/lstm_object_detection/lstm/lstm_cells_test.py
View file @
b9665e9b
...
@@ -66,10 +66,33 @@ class BottleneckConvLstmCellsTest(tf.test.TestCase):
...
@@ -66,10 +66,33 @@ class BottleneckConvLstmCellsTest(tf.test.TestCase):
init_state
=
cell
.
init_state
(
init_state
=
cell
.
init_state
(
state_name
,
batch_size
,
dtype
,
learned_state
)
state_name
,
batch_size
,
dtype
,
learned_state
)
output
,
state_tuple
=
cell
(
inputs
,
init_state
)
output
,
state_tuple
=
cell
(
inputs
,
init_state
)
self
.
assertAllEqual
([
4
,
1
500
],
output
.
shape
.
as_list
())
self
.
assertAllEqual
([
4
,
1
0
,
10
,
15
],
output
.
shape
.
as_list
())
self
.
assertAllEqual
([
4
,
1500
],
state_tuple
[
0
].
shape
.
as_list
())
self
.
assertAllEqual
([
4
,
1500
],
state_tuple
[
0
].
shape
.
as_list
())
self
.
assertAllEqual
([
4
,
1500
],
state_tuple
[
1
].
shape
.
as_list
())
self
.
assertAllEqual
([
4
,
1500
],
state_tuple
[
1
].
shape
.
as_list
())
def
test_run_lstm_cell_with_output_bottleneck
(
self
):
filter_size
=
[
3
,
3
]
output_dim
=
10
output_size
=
[
output_dim
]
*
2
num_units
=
15
state_name
=
'lstm_state'
batch_size
=
4
dtype
=
tf
.
float32
learned_state
=
False
inputs
=
tf
.
zeros
([
batch_size
,
output_dim
,
output_dim
,
3
],
dtype
=
tf
.
float32
)
cell
=
lstm_cells
.
BottleneckConvLSTMCell
(
filter_size
=
filter_size
,
output_size
=
output_size
,
num_units
=
num_units
,
output_bottleneck
=
True
)
init_state
=
cell
.
init_state
(
state_name
,
batch_size
,
dtype
,
learned_state
)
output
,
state_tuple
=
cell
(
inputs
,
init_state
)
self
.
assertAllEqual
([
4
,
10
,
10
,
30
],
output
.
shape
.
as_list
())
self
.
assertAllEqual
([
4
,
10
,
10
,
15
],
state_tuple
[
0
].
shape
.
as_list
())
self
.
assertAllEqual
([
4
,
10
,
10
,
15
],
state_tuple
[
1
].
shape
.
as_list
())
def
test_get_init_state
(
self
):
def
test_get_init_state
(
self
):
filter_size
=
[
3
,
3
]
filter_size
=
[
3
,
3
]
output_dim
=
10
output_dim
=
10
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment