ModelZoo / ResNet50_tensorflow
Commit 705acc35, authored Jan 10, 2017 by Christopher Shallue

Replace deprecated functions

parent f653bd23
Changes: showing 3 changed files with 11 additions and 10 deletions (+11 −10)

im2txt/im2txt/ops/image_processing.py  (+2 −2)
im2txt/im2txt/ops/inputs.py  (+1 −1)
im2txt/im2txt/show_and_tell_model.py  (+8 −7)
im2txt/im2txt/ops/image_processing.py
@@ -128,6 +128,6 @@ def process_image(encoded_image,
     image_summary("final_image", image)

   # Rescale to [-1,1] instead of [0, 1]
-  image = tf.sub(image, 0.5)
-  image = tf.mul(image, 2.0)
+  image = tf.subtract(image, 0.5)
+  image = tf.multiply(image, 2.0)
   return image
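For reference, a minimal standalone sketch (not part of the commit; the tensor shape and session usage are invented for illustration, assuming TensorFlow 1.x) of the same [0, 1] to [-1, 1] rescaling with the renamed ops:

import tensorflow as tf

# Minimal sketch: rescale an image tensor from [0, 1] to [-1, 1] using the
# TF 1.x op names (tf.subtract / tf.multiply) in place of tf.sub / tf.mul.
image = tf.random_uniform([224, 224, 3], minval=0.0, maxval=1.0)
image = tf.subtract(image, 0.5)
image = tf.multiply(image, 2.0)

with tf.Session() as sess:
  result = sess.run(image)
  print(result.min(), result.max())  # approximately -1.0 and 1.0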
im2txt/im2txt/ops/inputs.py
@@ -181,7 +181,7 @@ def batch_with_dynamic_pad(images_and_captions,
   enqueue_list = []
   for image, caption in images_and_captions:
     caption_length = tf.shape(caption)[0]
-    input_length = tf.expand_dims(tf.sub(caption_length, 1), 0)
+    input_length = tf.expand_dims(tf.subtract(caption_length, 1), 0)

     input_seq = tf.slice(caption, [0], input_length)
     target_seq = tf.slice(caption, [1], input_length)
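As a usage illustration (the token ids below are invented, not from the dataset; assumes TensorFlow 1.x), the same shifted input/target construction can be run on a single caption tensor with tf.subtract:

import tensorflow as tf

# Minimal sketch: build the input/target sequences for one caption, using
# tf.subtract in place of the deprecated tf.sub.
caption = tf.constant([1, 7, 42, 13, 2])           # e.g. <S> w1 w2 w3 </S>
caption_length = tf.shape(caption)[0]
input_length = tf.expand_dims(tf.subtract(caption_length, 1), 0)

input_seq = tf.slice(caption, [0], input_length)   # drops the last token
target_seq = tf.slice(caption, [1], input_length)  # drops the first token

with tf.Session() as sess:
  print(sess.run([input_seq, target_seq]))         # [1 7 42 13], [7 42 13 2]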
im2txt/im2txt/show_and_tell_model.py
@@ -244,10 +244,10 @@ class ShowAndTellModel(object):
     # This LSTM cell has biases and outputs tanh(new_c) * sigmoid(o), but the
     # modified LSTM in the "Show and Tell" paper has no biases and outputs
     # new_c * sigmoid(o).
-    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(
+    lstm_cell = tf.contrib.rnn.BasicLSTMCell(
         num_units=self.config.num_lstm_units, state_is_tuple=True)
     if self.mode == "train":
-      lstm_cell = tf.nn.rnn_cell.DropoutWrapper(
+      lstm_cell = tf.contrib.rnn.DropoutWrapper(
           lstm_cell,
           input_keep_prob=self.config.lstm_dropout_keep_prob,
           output_keep_prob=self.config.lstm_dropout_keep_prob)
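For context, a minimal sketch of the tf.contrib.rnn replacements for the deprecated tf.nn.rnn_cell module. The unit count, keep probability, and input shape below are hypothetical values, not the model's config; assumes TensorFlow 1.x.

import tensorflow as tf

# Minimal sketch: tf.contrib.rnn names replacing tf.nn.rnn_cell.BasicLSTMCell
# and tf.nn.rnn_cell.DropoutWrapper.
num_lstm_units = 512      # hypothetical, not self.config.num_lstm_units
keep_prob = 0.7           # hypothetical dropout keep probability

lstm_cell = tf.contrib.rnn.BasicLSTMCell(num_units=num_lstm_units,
                                         state_is_tuple=True)
lstm_cell = tf.contrib.rnn.DropoutWrapper(lstm_cell,
                                          input_keep_prob=keep_prob,
                                          output_keep_prob=keep_prob)

# The wrapped cell is used the same way as before, e.g. with dynamic_rnn.
inputs = tf.random_normal([32, 20, 512])           # batch, time, embedding
outputs, state = tf.nn.dynamic_rnn(lstm_cell, inputs, dtype=tf.float32)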
@@ -264,13 +264,13 @@ class ShowAndTellModel(object):
       if self.mode == "inference":
         # In inference mode, use concatenated states for convenient feeding and
         # fetching.
-        tf.concat(1, initial_state, name="initial_state")
+        tf.concat_v2(initial_state, 1, name="initial_state")

         # Placeholder for feeding a batch of concatenated states.
         state_feed = tf.placeholder(dtype=tf.float32,
                                     shape=[None, sum(lstm_cell.state_size)],
                                     name="state_feed")
-        state_tuple = tf.split(1, 2, state_feed)
+        state_tuple = tf.split(value=state_feed, num_or_size_splits=2, axis=1)

         # Run a single LSTM step.
         lstm_outputs, state_tuple = lstm_cell(
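For reference, a minimal sketch of the keyword-argument form of tf.split used above. The state size is a made-up number, and the sketch assumes TensorFlow 1.0, where the transitional tf.concat_v2 was folded into tf.concat(values, axis).

import tensorflow as tf

# Minimal sketch: split a concatenated (c, h) LSTM state along axis 1 using
# keyword arguments instead of the old positional order tf.split(1, 2, x).
state_size = 512                                   # hypothetical LSTM state size
state_feed = tf.placeholder(dtype=tf.float32,
                            shape=[None, 2 * state_size],
                            name="state_feed")
c_state, h_state = tf.split(value=state_feed, num_or_size_splits=2, axis=1)

# Re-concatenate; in TF 1.0 the axis argument of tf.concat comes after the
# list of values, matching what tf.concat_v2 provided during the transition.
state = tf.concat([c_state, h_state], axis=1, name="state")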
@@ -278,7 +278,7 @@ class ShowAndTellModel(object):
             state=state_tuple)

         # Concatentate the resulting state.
-        tf.concat(1, state_tuple, name="state")
+        tf.concat_v2(state_tuple, 1, name="state")
       else:
         # Run the batch of sequence embeddings through the LSTM.
         sequence_length = tf.reduce_sum(self.input_mask, 1)
@@ -307,8 +307,9 @@ class ShowAndTellModel(object):
       weights = tf.to_float(tf.reshape(self.input_mask, [-1]))

       # Compute losses.
-      losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, targets)
-      batch_loss = tf.div(tf.reduce_sum(tf.mul(losses, weights)),
+      losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets,
+                                                              logits=logits)
+      batch_loss = tf.div(tf.reduce_sum(tf.multiply(losses, weights)),
                           tf.reduce_sum(weights),
                           name="batch_loss")
       tf.losses.add_loss(batch_loss)
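Finally, a minimal standalone sketch of the keyword-argument form of the loss computation. The vocabulary size, timestep count, and mask values below are invented for illustration; assumes TensorFlow 1.x.

import tensorflow as tf

# Minimal sketch: labels=/logits= keyword form of
# sparse_softmax_cross_entropy_with_logits plus tf.multiply for the mask.
logits = tf.random_normal([6, 1000])               # 6 timesteps, 1000-word vocab
targets = tf.constant([3, 17, 42, 0, 0, 0], dtype=tf.int64)
weights = tf.constant([1., 1., 1., 0., 0., 0.])    # zero out padded timesteps

losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets,
                                                        logits=logits)
batch_loss = tf.div(tf.reduce_sum(tf.multiply(losses, weights)),
                    tf.reduce_sum(weights),
                    name="batch_loss")

with tf.Session() as sess:
  print(sess.run(batch_loss))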