ModelZoo / ResNet50_tensorflow — Commits

Commit 052e5e8b
Authored Feb 23, 2017 by Neal Wu
Converted the models repo to TF 1.0 using the upgrade script
Parent: f21c4278

Changes: 71 files in the full commit; this page shows 11 changed files with 54 additions and 54 deletions (+54 / -54).
textsum/seq2seq_attention.py                        +2  -2
textsum/seq2seq_attention_model.py                  +10 -10
textsum/seq2seq_lib.py                              +1  -1
transformer/spatial_transformer.py                  +9  -9
tutorials/embedding/word2vec.py                     +1  -1
tutorials/image/cifar10/cifar10_multi_gpu_train.py  +1  -1
tutorials/rnn/ptb/ptb_word_lm.py                    +1  -1
video_prediction/lstm_ops.py                        +6  -6
video_prediction/prediction_input.py                +3  -3
video_prediction/prediction_model.py                +10 -10
video_prediction/prediction_train.py                +10 -10
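All of the hunks below apply the same handful of mechanical renames introduced in TensorFlow 1.0. As a quick reference, a minimal sketch of the old-to-new mapping, assuming TensorFlow 1.x in graph mode (the tensor names here are illustrative, not taken from the repo):

import tensorflow as tf

a = tf.ones([2, 3])
b = tf.zeros([2, 3])

# tf.concat: (concat_dim, values) became (values, axis) / keyword arguments.
c = tf.concat(values=[a, b], axis=0)                      # was tf.concat(0, [a, b])

# tf.pack / tf.unpack were renamed tf.stack / tf.unstack.
s = tf.stack([a, b])                                      # was tf.pack([a, b])
parts = tf.unstack(s)                                     # was tf.unpack(s)

# tf.split: (split_dim, num_split, value) became keyword arguments.
halves = tf.split(value=c, num_or_size_splits=2, axis=0)  # was tf.split(0, 2, c)

# tf.mul was renamed tf.multiply.
p = tf.multiply(a, b)                                     # was tf.mul(a, b)

# Summary ops moved under tf.summary.
tf.summary.scalar('mean_p', tf.reduce_mean(p))            # was tf.scalar_summary(...)
merged = tf.summary.merge_all()                           # was tf.merge_all_summaries()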
textsum/seq2seq_attention.py

@@ -86,7 +86,7 @@ def _Train(model, data_batcher):
   saver = tf.train.Saver()
   # Train dir is different from log_root to avoid summary directory
   # conflict with Supervisor.
-  summary_writer = tf.train.SummaryWriter(FLAGS.train_dir)
+  summary_writer = tf.summary.FileWriter(FLAGS.train_dir)
   sv = tf.train.Supervisor(logdir=FLAGS.log_root,
                            is_chief=True,
                            saver=saver,

@@ -119,7 +119,7 @@ def _Eval(model, data_batcher, vocab=None):
   """Runs model eval."""
   model.build_graph()
   saver = tf.train.Saver()
-  summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir)
+  summary_writer = tf.summary.FileWriter(FLAGS.eval_dir)
   sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
   running_avg_loss = 0
   step = 0
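For context, a minimal sketch of the replacement writer API used above, assuming TensorFlow 1.x; the log directory is illustrative only:

import tensorflow as tf

loss = tf.constant(1.5)
loss_summary = tf.summary.scalar('loss', loss)

# tf.summary.FileWriter replaces tf.train.SummaryWriter in TF 1.0.
writer = tf.summary.FileWriter('/tmp/textsum_demo')  # hypothetical directory
with tf.Session() as sess:
    writer.add_graph(sess.graph)
    writer.add_summary(sess.run(loss_summary), global_step=0)
writer.close()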
textsum/seq2seq_attention_model.py

@@ -139,10 +139,10 @@ class Seq2SeqAttentionModel(object):
     vsize = self._vocab.NumIds()
     with tf.variable_scope('seq2seq'):
-      encoder_inputs = tf.unpack(tf.transpose(self._articles))
-      decoder_inputs = tf.unpack(tf.transpose(self._abstracts))
-      targets = tf.unpack(tf.transpose(self._targets))
-      loss_weights = tf.unpack(tf.transpose(self._loss_weights))
+      encoder_inputs = tf.unstack(tf.transpose(self._articles))
+      decoder_inputs = tf.unstack(tf.transpose(self._abstracts))
+      targets = tf.unstack(tf.transpose(self._targets))
+      loss_weights = tf.unstack(tf.transpose(self._loss_weights))
       article_lens = self._article_lens
       # Embedding shared by the input and outputs.

@@ -195,7 +195,7 @@ class Seq2SeqAttentionModel(object):
       encoder_outputs = [tf.reshape(x, [hps.batch_size, 1, 2*hps.num_hidden])
                          for x in encoder_outputs]
-      self._enc_top_states = tf.concat(1, encoder_outputs)
+      self._enc_top_states = tf.concat(axis=1, values=encoder_outputs)
       self._dec_in_state = fw_state
       # During decoding, follow up _dec_in_state are fed from beam_search.
       # dec_out_state are stored by beam_search for next step feeding.

@@ -218,7 +218,7 @@ class Seq2SeqAttentionModel(object):
       best_outputs = [tf.argmax(x, 1) for x in model_outputs]
       tf.logging.info('best_outputs%s', best_outputs[0].get_shape())
       self._outputs = tf.concat(
-          1, [tf.reshape(x, [hps.batch_size, 1]) for x in best_outputs])
+          axis=1, values=[tf.reshape(x, [hps.batch_size, 1]) for x in best_outputs])
       self._topk_log_probs, self._topk_ids = tf.nn.top_k(
           tf.log(tf.nn.softmax(model_outputs[-1])), hps.batch_size*2)

@@ -236,7 +236,7 @@ class Seq2SeqAttentionModel(object):
       else:
         self._loss = tf.nn.seq2seq.sequence_loss(
             model_outputs, targets, loss_weights)
-      tf.scalar_summary('loss', tf.minimum(12.0, self._loss))
+      tf.summary.scalar('loss', tf.minimum(12.0, self._loss))

   def _add_train_op(self):
     """Sets self._train_op, op to run for training."""

@@ -250,9 +250,9 @@ class Seq2SeqAttentionModel(object):
     with tf.device(self._get_gpu(self._num_gpus-1)):
       grads, global_norm = tf.clip_by_global_norm(
           tf.gradients(self._loss, tvars), hps.max_grad_norm)
-    tf.scalar_summary('global_norm', global_norm)
+    tf.summary.scalar('global_norm', global_norm)
     optimizer = tf.train.GradientDescentOptimizer(self._lr_rate)
-    tf.scalar_summary('learning rate', self._lr_rate)
+    tf.summary.scalar('learning rate', self._lr_rate)
     self._train_op = optimizer.apply_gradients(
         zip(grads, tvars), global_step=self.global_step, name='train_step')

@@ -296,4 +296,4 @@ class Seq2SeqAttentionModel(object):
     self.global_step = tf.Variable(0, name='global_step', trainable=False)
     if self._hps.mode == 'train':
       self._add_train_op()
-    self._summaries = tf.merge_all_summaries()
+    self._summaries = tf.summary.merge_all()
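A minimal sketch of what the renamed call does in this model, assuming TensorFlow 1.x and illustrative shapes: tf.unstack (formerly tf.unpack) turns the transposed [batch, time] id tensor into a Python list of per-timestep tensors.

import tensorflow as tf

articles = tf.placeholder(tf.int32, [4, 10])                # [batch, time], illustrative
encoder_inputs = tf.unstack(tf.transpose(articles))         # list of 10 tensors, each shape [4]
print(len(encoder_inputs), encoder_inputs[0].get_shape())   # 10 (4,)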
textsum/seq2seq_lib.py

@@ -127,7 +127,7 @@ def linear(args, output_size, bias, bias_start=0.0, scope=None):
     if len(args) == 1:
       res = tf.matmul(args[0], matrix)
     else:
-      res = tf.matmul(tf.concat(1, args), matrix)
+      res = tf.matmul(tf.concat(axis=1, values=args), matrix)
     if not bias:
       return res
     bias_term = tf.get_variable(
transformer/spatial_transformer.py

@@ -53,7 +53,7 @@ def transformer(U, theta, out_size, name='SpatialTransformer', **kwargs):
     def _repeat(x, n_repeats):
         with tf.variable_scope('_repeat'):
             rep = tf.transpose(
-                tf.expand_dims(tf.ones(shape=tf.pack([n_repeats, ])), 1), [1, 0])
+                tf.expand_dims(tf.ones(shape=tf.stack([n_repeats, ])), 1), [1, 0])
             rep = tf.cast(rep, 'int32')
             x = tf.matmul(tf.reshape(x, (-1, 1)), rep)
             return tf.reshape(x, [-1])

@@ -102,7 +102,7 @@ def transformer(U, theta, out_size, name='SpatialTransformer', **kwargs):
             # use indices to lookup pixels in the flat image and restore
             # channels dim
-            im_flat = tf.reshape(im, tf.pack([-1, channels]))
+            im_flat = tf.reshape(im, tf.stack([-1, channels]))
             im_flat = tf.cast(im_flat, 'float32')
             Ia = tf.gather(im_flat, idx_a)
             Ib = tf.gather(im_flat, idx_b)

@@ -128,16 +128,16 @@ def transformer(U, theta, out_size, name='SpatialTransformer', **kwargs):
             #                          np.linspace(-1, 1, height))
             # ones = np.ones(np.prod(x_t.shape))
             # grid = np.vstack([x_t.flatten(), y_t.flatten(), ones])
-            x_t = tf.matmul(tf.ones(shape=tf.pack([height, 1])),
+            x_t = tf.matmul(tf.ones(shape=tf.stack([height, 1])),
                             tf.transpose(tf.expand_dims(tf.linspace(-1.0, 1.0, width), 1), [1, 0]))
             y_t = tf.matmul(tf.expand_dims(tf.linspace(-1.0, 1.0, height), 1),
-                            tf.ones(shape=tf.pack([1, width])))
+                            tf.ones(shape=tf.stack([1, width])))
             x_t_flat = tf.reshape(x_t, (1, -1))
             y_t_flat = tf.reshape(y_t, (1, -1))
             ones = tf.ones_like(x_t_flat)
-            grid = tf.concat(0, [x_t_flat, y_t_flat, ones])
+            grid = tf.concat(axis=0, values=[x_t_flat, y_t_flat, ones])
             return grid

     def _transform(theta, input_dim, out_size):

@@ -157,11 +157,11 @@ def transformer(U, theta, out_size, name='SpatialTransformer', **kwargs):
             grid = _meshgrid(out_height, out_width)
             grid = tf.expand_dims(grid, 0)
             grid = tf.reshape(grid, [-1])
-            grid = tf.tile(grid, tf.pack([num_batch]))
-            grid = tf.reshape(grid, tf.pack([num_batch, 3, -1]))
+            grid = tf.tile(grid, tf.stack([num_batch]))
+            grid = tf.reshape(grid, tf.stack([num_batch, 3, -1]))

             # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s)
-            T_g = tf.batch_matmul(theta, grid)
+            T_g = tf.matmul(theta, grid)
             x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1])
             y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1])
             x_s_flat = tf.reshape(x_s, [-1])

@@ -172,7 +172,7 @@ def transformer(U, theta, out_size, name='SpatialTransformer', **kwargs):
                 out_size)
             output = tf.reshape(
-                input_transformed, tf.pack([num_batch, out_height, out_width, num_channels]))
+                input_transformed, tf.stack([num_batch, out_height, out_width, num_channels]))
             return output

     with tf.variable_scope(name):
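One change in this file goes beyond a rename: TF 1.0 folded batched multiplication into tf.matmul, so tf.batch_matmul(theta, grid) simply becomes tf.matmul(theta, grid). A minimal sketch under that assumption, with illustrative shapes:

import tensorflow as tf

theta = tf.ones([8, 2, 3])    # [num_batch, 2, 3], illustrative affine parameters
grid = tf.ones([8, 3, 100])   # [num_batch, 3, out_height*out_width], illustrative

# In TF 1.0, tf.matmul accepts rank-3 inputs and multiplies batch-wise,
# replacing the removed tf.batch_matmul.
T_g = tf.matmul(theta, grid)  # shape [8, 2, 100]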
tutorials/embedding/word2vec.py

@@ -246,7 +246,7 @@ class Word2Vec(object):
     sampled_b = tf.nn.embedding_lookup(sm_b, sampled_ids)

     # True logits: [batch_size, 1]
-    true_logits = tf.reduce_sum(tf.mul(example_emb, true_w), 1) + true_b
+    true_logits = tf.reduce_sum(tf.multiply(example_emb, true_w), 1) + true_b

     # Sampled logits: [batch_size, num_sampled]
     # We replicate sampled noise labels for all examples in the batch
tutorials/image/cifar10/cifar10_multi_gpu_train.py

@@ -124,7 +124,7 @@ def average_gradients(tower_grads):
       grads.append(expanded_g)

     # Average over the 'tower' dimension.
-    grad = tf.concat(grads, 0)
+    grad = tf.concat(axis=grads, values=0)
     grad = tf.reduce_mean(grad, 0)

     # Keep in mind that the Variables are redundant because they are shared
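Worth noting when reading this hunk: the keyword arguments in the added line appear swapped (axis receives the list of gradients and values receives the integer 0), which looks like an artifact of running the automatic conversion over a call that was already in the new argument order. A minimal, self-contained sketch of what the TF 1.0 call presumably should be, with a stand-in gradient list:

import tensorflow as tf

grads = [tf.ones([1, 3]), tf.ones([1, 3])]  # stand-in for the per-tower gradients
grad = tf.concat(grads, axis=0)             # presumably the intended TF 1.0 form
grad = tf.reduce_mean(grad, 0)              # average over the 'tower' dimension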
tutorials/rnn/ptb/ptb_word_lm.py

@@ -146,7 +146,7 @@ class PTBModel(object):
         (cell_output, state) = cell(inputs[:, time_step, :], state)
         outputs.append(cell_output)

-    output = tf.reshape(tf.concat(outputs, 1), [-1, size])
+    output = tf.reshape(tf.concat(axis=outputs, values=1), [-1, size])
     softmax_w = tf.get_variable(
         "softmax_w", [size, vocab_size], dtype=data_type())
     softmax_b = tf.get_variable("softmax_b", [vocab_size], dtype=data_type())
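The same swapped-keyword pattern appears in this hunk (axis=outputs, values=1). A minimal sketch of the form that presumably was intended, with illustrative shapes:

import tensorflow as tf

size = 200
outputs = [tf.ones([4, size]) for _ in range(20)]            # per-timestep cell outputs, illustrative
output = tf.reshape(tf.concat(outputs, axis=1), [-1, size])  # presumably the intended TF 1.0 call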
video_prediction/lstm_ops.py

@@ -23,7 +23,7 @@ from tensorflow.contrib.slim import layers
 def init_state(inputs,
                state_shape,
-               state_initializer=tf.zeros_initializer,
+               state_initializer=tf.zeros_initializer(),
                dtype=tf.float32):
   """Helper function to create an initial state given inputs.

@@ -45,7 +45,7 @@ def init_state(inputs,
     batch_size = 0
   initial_state = state_initializer(
-      tf.pack([batch_size] + state_shape),
+      tf.stack([batch_size] + state_shape),
       dtype=dtype)
   initial_state.set_shape([inferred_batch_size] + state_shape)

@@ -89,8 +89,8 @@ def basic_conv_lstm_cell(inputs,
                          reuse=reuse):
     inputs.get_shape().assert_has_rank(4)
     state.get_shape().assert_has_rank(4)
-    c, h = tf.split(3, 2, state)
-    inputs_h = tf.concat(3, [inputs, h])
+    c, h = tf.split(axis=3, num_or_size_splits=2, value=state)
+    inputs_h = tf.concat(axis=3, values=[inputs, h])
     # Parameters of gates are concatenated into one conv for efficiency.
     i_j_f_o = layers.conv2d(inputs_h,
                             4 * num_channels, [filter_size, filter_size],

@@ -99,12 +99,12 @@ def basic_conv_lstm_cell(inputs,
                             scope='Gates')

     # i = input_gate, j = new_input, f = forget_gate, o = output_gate
-    i, j, f, o = tf.split(3, 4, i_j_f_o)
+    i, j, f, o = tf.split(axis=3, num_or_size_splits=4, value=i_j_f_o)

     new_c = c * tf.sigmoid(f + forget_bias) + tf.sigmoid(i) * tf.tanh(j)
     new_h = tf.tanh(new_c) * tf.sigmoid(o)

-    return new_h, tf.concat(3, [new_c, new_h])
+    return new_h, tf.concat(axis=3, values=[new_c, new_h])
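Two distinct changes appear in this file: tf.split now takes axis / num_or_size_splits / value keywords, and tf.zeros_initializer became a class in TF 1.0, so the default argument must be an instance (hence the added parentheses). A minimal sketch assuming TensorFlow 1.x, with illustrative shapes:

import tensorflow as tf

state = tf.zeros([1, 8, 8, 64])                             # illustrative conv-LSTM state
c, h = tf.split(axis=3, num_or_size_splits=2, value=state)  # each [1, 8, 8, 32]

# tf.zeros_initializer is now a callable class; instantiate it before use,
# as the updated init_state() default does.
init = tf.zeros_initializer()
initial_state = init([1, 8, 8, 32], dtype=tf.float32)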
video_prediction/prediction_input.py

@@ -97,11 +97,11 @@ def build_tfrecord_input(training=True):
       action = tf.reshape(features[action_name], shape=[1, STATE_DIM])
       action_seq.append(action)

-    image_seq = tf.concat(0, image_seq)
+    image_seq = tf.concat(axis=0, values=image_seq)

     if FLAGS.use_state:
-      state_seq = tf.concat(0, state_seq)
-      action_seq = tf.concat(0, action_seq)
+      state_seq = tf.concat(axis=0, values=state_seq)
+      action_seq = tf.concat(axis=0, values=action_seq)
       [image_batch, action_batch, state_batch] = tf.train.batch(
           [image_seq, action_seq, state_seq],
           FLAGS.batch_size,
video_prediction/prediction_model.py

@@ -109,7 +109,7 @@ def construct_model(images,
         prev_image = image

       # Predicted state is always fed back in
-      state_action = tf.concat(1, [action, current_state])
+      state_action = tf.concat(axis=1, values=[action, current_state])

       enc0 = slim.layers.conv2d(
           prev_image,

@@ -144,7 +144,7 @@ def construct_model(images,
         smear = tf.tile(
             smear, [1, int(enc2.get_shape()[1]), int(enc2.get_shape()[2]), 1])
         if use_state:
-          enc2 = tf.concat(3, [enc2, smear])
+          enc2 = tf.concat(axis=3, values=[enc2, smear])
         enc3 = slim.layers.conv2d(
             enc2, hidden4.get_shape()[3], [1, 1], stride=1, scope='conv4')

@@ -158,7 +158,7 @@ def construct_model(images,
           enc4, lstm_state6, lstm_size[5], scope='state6')  # 16x16
       hidden6 = tf_layers.layer_norm(hidden6, scope='layer_norm7')
       # Skip connection.
-      hidden6 = tf.concat(3, [hidden6, enc1])  # both 16x16
+      hidden6 = tf.concat(axis=3, values=[hidden6, enc1])  # both 16x16

       enc5 = slim.layers.conv2d_transpose(
           hidden6, hidden6.get_shape()[3], 3, stride=2, scope='convt2')

@@ -167,7 +167,7 @@ def construct_model(images,
       hidden7 = tf_layers.layer_norm(hidden7, scope='layer_norm8')
       # Skip connection.
-      hidden7 = tf.concat(3, [hidden7, enc0])  # both 32x32
+      hidden7 = tf.concat(axis=3, values=[hidden7, enc0])  # both 32x32

       enc6 = slim.layers.conv2d_transpose(
           hidden7,

@@ -207,7 +207,7 @@ def construct_model(images,
         masks = tf.reshape(
             tf.nn.softmax(tf.reshape(masks, [-1, num_masks + 1])),
             [int(batch_size), int(img_height), int(img_width), num_masks + 1])
-        mask_list = tf.split(3, num_masks + 1, masks)
+        mask_list = tf.split(axis=3, num_or_size_splits=num_masks + 1, value=masks)
         output = mask_list[0] * prev_image
         for layer, mask in zip(transformed, mask_list[1:]):
           output += layer * mask

@@ -277,8 +277,8 @@ def cdna_transformation(prev_image, cdna_input, num_masks, color_channels):
   cdna_kerns /= norm_factor

   cdna_kerns = tf.tile(cdna_kerns, [1, 1, 1, color_channels, 1])
-  cdna_kerns = tf.split(0, batch_size, cdna_kerns)
-  prev_images = tf.split(0, batch_size, prev_image)
+  cdna_kerns = tf.split(axis=0, num_or_size_splits=batch_size, value=cdna_kerns)
+  prev_images = tf.split(axis=0, num_or_size_splits=batch_size, value=prev_image)

   # Transform image.
   transformed = []

@@ -288,8 +288,8 @@ def cdna_transformation(prev_image, cdna_input, num_masks, color_channels):
     kernel = tf.expand_dims(kernel, -1)
     transformed.append(
         tf.nn.depthwise_conv2d(preimg, kernel, [1, 1, 1, 1], 'SAME'))
-  transformed = tf.concat(0, transformed)
-  transformed = tf.split(3, num_masks, transformed)
+  transformed = tf.concat(axis=0, values=transformed)
+  transformed = tf.split(axis=3, num_or_size_splits=num_masks, value=transformed)
   return transformed

@@ -314,7 +314,7 @@ def dna_transformation(prev_image, dna_input):
           tf.expand_dims(
               tf.slice(prev_image_pad, [0, xkern, ykern, 0],
                        [-1, image_height, image_width, -1]), [3]))
-  inputs = tf.concat(3, inputs)
+  inputs = tf.concat(axis=3, values=inputs)

   # Normalize channels to 1.
   kernel = tf.nn.relu(dna_input - RELU_SHIFT) + RELU_SHIFT
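A minimal sketch of how the converted split feeds the mask-compositing loop in the hunk at line 207, assuming TensorFlow 1.x and small illustrative sizes (the values are stand-ins, not from the model):

import tensorflow as tf

num_masks = 3
masks = tf.ones([2, 8, 8, num_masks + 1]) / (num_masks + 1)   # illustrative, already softmaxed
prev_image = tf.ones([2, 8, 8, 3])
transformed = [tf.zeros([2, 8, 8, 3]) for _ in range(num_masks)]

# Each entry of mask_list has shape [2, 8, 8, 1] and broadcasts over the color channels.
mask_list = tf.split(axis=3, num_or_size_splits=num_masks + 1, value=masks)
output = mask_list[0] * prev_image
for layer, mask in zip(transformed, mask_list[1:]):
    output += layer * mask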
video_prediction/prediction_train.py

@@ -113,11 +113,11 @@ class Model(object):
     summaries = []

     # Split into timesteps.
-    actions = tf.split(1, actions.get_shape()[1], actions)
+    actions = tf.split(axis=1, num_or_size_splits=actions.get_shape()[1], value=actions)
     actions = [tf.squeeze(act) for act in actions]
-    states = tf.split(1, states.get_shape()[1], states)
+    states = tf.split(axis=1, num_or_size_splits=states.get_shape()[1], value=states)
     states = [tf.squeeze(st) for st in states]
-    images = tf.split(1, images.get_shape()[1], images)
+    images = tf.split(axis=1, num_or_size_splits=images.get_shape()[1], value=images)
     images = [tf.squeeze(img) for img in images]

     if reuse_scope is None:

@@ -157,8 +157,8 @@ class Model(object):
       psnr_i = peak_signal_to_noise_ratio(x, gx)
       psnr_all += psnr_i
       summaries.append(
-          tf.scalar_summary(prefix + '_recon_cost' + str(i), recon_cost))
-      summaries.append(tf.scalar_summary(prefix + '_psnr' + str(i), psnr_i))
+          tf.summary.scalar(prefix + '_recon_cost' + str(i), recon_cost))
+      summaries.append(tf.summary.scalar(prefix + '_psnr' + str(i), psnr_i))
       loss += recon_cost

     for i, state, gen_state in zip(

@@ -166,19 +166,19 @@ class Model(object):
         gen_states[FLAGS.context_frames - 1:]):
       state_cost = mean_squared_error(state, gen_state) * 1e-4
       summaries.append(
-          tf.scalar_summary(prefix + '_state_cost' + str(i), state_cost))
+          tf.summary.scalar(prefix + '_state_cost' + str(i), state_cost))
       loss += state_cost
-    summaries.append(tf.scalar_summary(prefix + '_psnr_all', psnr_all))
+    summaries.append(tf.summary.scalar(prefix + '_psnr_all', psnr_all))
     self.psnr_all = psnr_all

     self.loss = loss = loss / np.float32(len(images) - FLAGS.context_frames)

-    summaries.append(tf.scalar_summary(prefix + '_loss', loss))
+    summaries.append(tf.summary.scalar(prefix + '_loss', loss))

     self.lr = tf.placeholder_with_default(FLAGS.learning_rate, ())

     self.train_op = tf.train.AdamOptimizer(self.lr).minimize(loss)
-    self.summ_op = tf.merge_summary(summaries)
+    self.summ_op = tf.summary.merge(summaries)

 def main(unused_argv):

@@ -200,7 +200,7 @@ def main(unused_argv):
   # Make training session.
   sess = tf.InteractiveSession()
-  summary_writer = tf.train.SummaryWriter(
+  summary_writer = tf.summary.FileWriter(
       FLAGS.event_log_dir, graph=sess.graph, flush_secs=10)

   if FLAGS.pretrained_model:
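Unlike the textsum model, which merges every registered summary with tf.summary.merge_all(), this model keeps an explicit list, so tf.merge_summary(summaries) maps to tf.summary.merge(summaries). A minimal, self-contained sketch assuming TensorFlow 1.x, with illustrative constants:

import tensorflow as tf

loss = tf.constant(0.25)
psnr = tf.constant(30.0)
summaries = [tf.summary.scalar('train_loss', loss),
             tf.summary.scalar('train_psnr', psnr)]

# tf.summary.merge replaces tf.merge_summary for an explicit list of summary ops.
summ_op = tf.summary.merge(summaries)
with tf.Session() as sess:
    serialized = sess.run(summ_op)  # serialized Summary proto, ready for a FileWriter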