Unverified commit 7d16fc45, authored Feb 27, 2018 by Andrew M Dai, committed by GitHub on Feb 27, 2018
Merge pull request #3486 from a-dai/master
Added new MaskGAN model.
Parents: c5c6eaf2, 87fae3f7
Changes: 45
Showing 20 changed files with 3940 additions and 0 deletions (+3940, -0)
research/maskgan/models/bidirectional_vd.py         +116  -0
research/maskgan/models/bidirectional_zaremba.py     +83  -0
research/maskgan/models/cnn.py                       +93  -0
research/maskgan/models/critic_vd.py                +107  -0
research/maskgan/models/evaluation_utils.py         +280  -0
research/maskgan/models/feedforward.py               +97  -0
research/maskgan/models/rnn.py                      +210  -0
research/maskgan/models/rnn_nas.py                  +233  -0
research/maskgan/models/rnn_vd.py                   +117  -0
research/maskgan/models/rnn_zaremba.py              +195  -0
research/maskgan/models/rollout.py                  +383  -0
research/maskgan/models/seq2seq.py                  +277  -0
research/maskgan/models/seq2seq_nas.py              +332  -0
research/maskgan/models/seq2seq_vd.py               +608  -0
research/maskgan/models/seq2seq_zaremba.py          +305  -0
research/maskgan/nas_utils/__init__.py                +0  -0
research/maskgan/nas_utils/configs.py                +46  -0
research/maskgan/nas_utils/custom_cell.py           +166  -0
research/maskgan/nas_utils/variational_dropout.py    +61  -0
research/maskgan/pretrain_mask_gan.py               +231  -0
research/maskgan/models/bidirectional_vd.py  (new file, mode 100644)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple bidirectional model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from regularization import variational_dropout

FLAGS = tf.app.flags.FLAGS


def discriminator(hparams, sequence, is_training, reuse=None,
                  initial_state=None):
  """Define the Discriminator graph."""
  sequence = tf.cast(sequence, tf.int32)

  if FLAGS.dis_share_embedding:
    assert hparams.dis_rnn_size == hparams.gen_rnn_size, (
        'If you wish to share Discriminator/Generator embeddings, they must be'
        ' same dimension.')
    with tf.variable_scope('gen/decoder/rnn', reuse=True):
      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, hparams.gen_rnn_size])

  with tf.variable_scope('dis', reuse=reuse):

    def lstm_cell():
      return tf.contrib.rnn.BasicLSTMCell(
          hparams.dis_rnn_size,
          forget_bias=0.0,
          state_is_tuple=True,
          reuse=reuse)

    attn_cell = lstm_cell
    if is_training and hparams.dis_vd_keep_prob < 1:

      def attn_cell():
        return variational_dropout.VariationalDropoutWrapper(
            lstm_cell(), FLAGS.batch_size, hparams.dis_rnn_size,
            hparams.dis_vd_keep_prob, hparams.dis_vd_keep_prob)

    cell_fwd = tf.contrib.rnn.MultiRNNCell(
        [attn_cell() for _ in range(hparams.dis_num_layers)],
        state_is_tuple=True)
    cell_bwd = tf.contrib.rnn.MultiRNNCell(
        [attn_cell() for _ in range(hparams.dis_num_layers)],
        state_is_tuple=True)

    # print initial_state
    # print cell_fwd.zero_state(FLAGS.batch_size, tf.float32)
    if initial_state:
      state_fwd = [[tf.identity(x) for x in inner_initial_state]
                   for inner_initial_state in initial_state]
      state_bwd = cell_bwd.zero_state(FLAGS.batch_size, tf.float32)
    else:
      state_fwd = cell_fwd.zero_state(FLAGS.batch_size, tf.float32)
      state_bwd = cell_bwd.zero_state(FLAGS.batch_size, tf.float32)

    def make_mask(keep_prob, units):
      random_tensor = keep_prob
      # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
      random_tensor += tf.random_uniform(tf.stack([FLAGS.batch_size, units]))
      return tf.floor(random_tensor) / keep_prob

    if is_training:
      output_mask = make_mask(hparams.dis_vd_keep_prob,
                              2 * hparams.dis_rnn_size)

    if not FLAGS.dis_share_embedding:
      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, hparams.dis_rnn_size])

    rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)
    rnn_inputs = tf.unstack(rnn_inputs, axis=1)

    with tf.variable_scope('rnn') as vs:
      outputs, _, _ = tf.contrib.rnn.static_bidirectional_rnn(
          cell_fwd, cell_bwd, rnn_inputs, state_fwd, state_bwd, scope=vs)

      if is_training:
        outputs *= output_mask

      # Prediction is linear output for Discriminator.
      predictions = tf.contrib.layers.linear(outputs, 1, scope=vs)
      predictions = tf.transpose(predictions, [1, 0, 2])

    if FLAGS.baseline_method == 'critic':
      with tf.variable_scope('critic', reuse=reuse) as critic_scope:
        values = tf.contrib.layers.linear(outputs, 1, scope=critic_scope)
        values = tf.transpose(values, [1, 0, 2])

      return tf.squeeze(predictions, axis=2), tf.squeeze(values, axis=2)
    else:
      return tf.squeeze(predictions, axis=2), None
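The make_mask helper above is the usual inverted-dropout trick: floor(keep_prob + U[0, 1)) is a Bernoulli(keep_prob) draw, and dividing by keep_prob rescales the surviving units so the expected activation is unchanged; because the mask is built once and applied to every time step's output, the dropout pattern is shared across the whole sequence. A minimal NumPy sketch of the same computation, not part of the commit (the batch size and unit count are made-up values):

import numpy as np

def make_mask(keep_prob, batch_size, units):
  # floor(keep_prob + uniform[0, 1)) is 1 with probability keep_prob, else 0;
  # dividing by keep_prob keeps the expected value of mask * x equal to x.
  random_tensor = keep_prob + np.random.uniform(size=(batch_size, units))
  return np.floor(random_tensor) / keep_prob

mask = make_mask(0.75, batch_size=4, units=8)
print(np.unique(mask))   # entries are 0.0 or 1/0.75 ~= 1.3333
print(mask.mean())       # close to 1.0 on average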
research/maskgan/models/bidirectional_zaremba.py  (new file, mode 100644)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple bidirectional model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS


def discriminator(hparams, sequence, is_training, reuse=None):
  """Define the bidirectional Discriminator graph."""
  sequence = tf.cast(sequence, tf.int32)

  if FLAGS.dis_share_embedding:
    assert hparams.dis_rnn_size == hparams.gen_rnn_size, (
        'If you wish to share Discriminator/Generator embeddings, they must be'
        ' same dimension.')
    with tf.variable_scope('gen/rnn', reuse=True):
      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, hparams.gen_rnn_size])

  with tf.variable_scope('dis', reuse=reuse):

    def lstm_cell():
      return tf.contrib.rnn.BasicLSTMCell(
          hparams.dis_rnn_size,
          forget_bias=0.0,
          state_is_tuple=True,
          reuse=reuse)

    attn_cell = lstm_cell
    if is_training and FLAGS.keep_prob < 1:

      def attn_cell():
        return tf.contrib.rnn.DropoutWrapper(
            lstm_cell(), output_keep_prob=FLAGS.keep_prob)

    cell_fwd = tf.contrib.rnn.MultiRNNCell(
        [attn_cell() for _ in range(hparams.dis_num_layers)],
        state_is_tuple=True)
    cell_bwd = tf.contrib.rnn.MultiRNNCell(
        [attn_cell() for _ in range(hparams.dis_num_layers)],
        state_is_tuple=True)

    state_fwd = cell_fwd.zero_state(FLAGS.batch_size, tf.float32)
    state_bwd = cell_bwd.zero_state(FLAGS.batch_size, tf.float32)

    if not FLAGS.dis_share_embedding:
      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, hparams.dis_rnn_size])

    rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)

    if is_training and FLAGS.keep_prob < 1:
      rnn_inputs = tf.nn.dropout(rnn_inputs, FLAGS.keep_prob)

    rnn_inputs = tf.unstack(rnn_inputs, axis=1)

    with tf.variable_scope('rnn') as vs:
      outputs, _, _ = tf.contrib.rnn.static_bidirectional_rnn(
          cell_fwd, cell_bwd, rnn_inputs, state_fwd, state_bwd, scope=vs)

      # Prediction is linear output for Discriminator.
      predictions = tf.contrib.layers.linear(outputs, 1, scope=vs)
      predictions = tf.transpose(predictions, [1, 0, 2])

  return tf.squeeze(predictions, axis=2)
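A key difference from bidirectional_vd.py above is the regularizer: here a standard tf.contrib.rnn.DropoutWrapper samples a fresh output mask at every time step, while the VariationalDropoutWrapper used there reuses one mask across the whole sequence. A NumPy illustration of the two sampling schemes, not part of the commit (sizes are made up):

import numpy as np

keep_prob, seq_len, units = 0.75, 3, 5
rng = np.random.default_rng(0)

# Standard dropout: an independent mask at each time step.
per_step_masks = (rng.random((seq_len, units)) < keep_prob) / keep_prob

# Variational dropout: one mask sampled once, reused at every step.
shared = (rng.random(units) < keep_prob) / keep_prob
variational_masks = np.tile(shared, (seq_len, 1))

print(per_step_masks)     # rows differ
print(variational_masks)  # identical rows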
research/maskgan/models/cnn.py  (new file, mode 100644)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple CNN model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS


def discriminator(hparams, sequence, is_training, reuse=None):
  """Define the Discriminator graph."""
  del is_training
  sequence = tf.cast(sequence, tf.int32)

  if FLAGS.dis_share_embedding:
    assert hparams.dis_rnn_size == hparams.gen_rnn_size, (
        "If you wish to share Discriminator/Generator embeddings, they must be"
        " same dimension.")
    with tf.variable_scope("gen/rnn", reuse=True):
      embedding = tf.get_variable("embedding",
                                  [FLAGS.vocab_size, hparams.gen_rnn_size])

  dis_filter_sizes = [3, 4, 5, 6, 7, 8, 9, 10, 15, 20]

  with tf.variable_scope("dis", reuse=reuse):
    if not FLAGS.dis_share_embedding:
      embedding = tf.get_variable("embedding",
                                  [FLAGS.vocab_size, hparams.dis_rnn_size])
    cnn_inputs = tf.nn.embedding_lookup(embedding, sequence)

    # Create a convolution layer for each filter size
    conv_outputs = []
    for filter_size in dis_filter_sizes:
      with tf.variable_scope("conv-%s" % filter_size):
        # Convolution Layer
        filter_shape = [
            filter_size, hparams.dis_rnn_size, hparams.dis_num_filters
        ]
        W = tf.get_variable(
            name="W",
            initializer=tf.truncated_normal(filter_shape, stddev=0.1))
        b = tf.get_variable(
            name="b",
            initializer=tf.constant(0.1, shape=[hparams.dis_num_filters]))
        conv = tf.nn.conv1d(
            cnn_inputs, W, stride=1, padding="SAME", name="conv")

        # Apply nonlinearity
        h = tf.nn.relu(tf.nn.bias_add(conv, b), name="relu")

        conv_outputs.append(h)

    # Combine all the pooled features
    dis_num_filters_total = hparams.dis_num_filters * len(dis_filter_sizes)

    h_conv = tf.concat(conv_outputs, axis=2)
    h_conv_flat = tf.reshape(h_conv, [-1, dis_num_filters_total])

    # Add dropout
    with tf.variable_scope("dropout"):
      h_drop = tf.nn.dropout(h_conv_flat, FLAGS.keep_prob)

    with tf.variable_scope("fully_connected"):
      fc = tf.contrib.layers.fully_connected(
          h_drop, num_outputs=dis_num_filters_total / 2)

    # Final (unnormalized) scores and predictions
    with tf.variable_scope("output"):
      W = tf.get_variable(
          "W",
          shape=[dis_num_filters_total / 2, 1],
          initializer=tf.contrib.layers.xavier_initializer())
      b = tf.get_variable(name="b", initializer=tf.constant(0.1, shape=[1]))
      predictions = tf.nn.xw_plus_b(fc, W, b, name="predictions")
      predictions = tf.reshape(
          predictions, shape=[FLAGS.batch_size, FLAGS.sequence_length])
  return predictions
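Each branch of the CNN discriminator above applies a "SAME"-padded, stride-1 tf.nn.conv1d, so the time axis is preserved and the ten branches can be concatenated channel-wise into one feature vector per token position. A short shape walk-through as comments, not part of the commit (the symbolic sizes stand for whatever hparams and flags are used):

# cnn_inputs:            [batch_size, sequence_length, dis_rnn_size]
# one conv branch, width k:
#   W:                   [k, dis_rnn_size, dis_num_filters]
#   conv1d output:       [batch_size, sequence_length, dis_num_filters]
# h_conv (10 branches concatenated on axis 2):
#                        [batch_size, sequence_length, 10 * dis_num_filters]
# h_conv_flat:           [batch_size * sequence_length, 10 * dis_num_filters]
# predictions (after fc, xw_plus_b, reshape):
#                        [batch_size, sequence_length]  -- one logit per token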
research/maskgan/models/critic_vd.py  (new file, mode 100644)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Critic model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from regularization import variational_dropout

FLAGS = tf.app.flags.FLAGS


def critic_seq2seq_vd_derivative(hparams, sequence, is_training, reuse=None):
  """Define the Critic graph which is derived from the seq2seq_vd
  Discriminator.  This will be initialized with the same parameters as the
  language model and will share the forward RNN components with the
  Discriminator.  This estimates the V(s_t), where the state
  s_t = x_0,...,x_t-1.
  """
  assert FLAGS.discriminator_model == 'seq2seq_vd'
  sequence = tf.cast(sequence, tf.int32)

  if FLAGS.dis_share_embedding:
    assert hparams.dis_rnn_size == hparams.gen_rnn_size, (
        'If you wish to share Discriminator/Generator embeddings, they must be'
        ' same dimension.')
    with tf.variable_scope('gen/decoder/rnn', reuse=True):
      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, hparams.gen_rnn_size])
  else:
    with tf.variable_scope('dis/decoder/rnn', reuse=True):
      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, hparams.dis_rnn_size])

  with tf.variable_scope(
      'dis/decoder/rnn/multi_rnn_cell', reuse=True) as dis_scope:

    def lstm_cell():
      return tf.contrib.rnn.BasicLSTMCell(
          hparams.dis_rnn_size,
          forget_bias=0.0,
          state_is_tuple=True,
          reuse=True)

    attn_cell = lstm_cell
    if is_training and hparams.dis_vd_keep_prob < 1:

      def attn_cell():
        return variational_dropout.VariationalDropoutWrapper(
            lstm_cell(), FLAGS.batch_size, hparams.dis_rnn_size,
            hparams.dis_vd_keep_prob, hparams.dis_vd_keep_prob)

    cell_critic = tf.contrib.rnn.MultiRNNCell(
        [attn_cell() for _ in range(hparams.dis_num_layers)],
        state_is_tuple=True)

  with tf.variable_scope('critic', reuse=reuse):
    state_dis = cell_critic.zero_state(FLAGS.batch_size, tf.float32)

    def make_mask(keep_prob, units):
      random_tensor = keep_prob
      # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
      random_tensor += tf.random_uniform(tf.stack([FLAGS.batch_size, units]))
      return tf.floor(random_tensor) / keep_prob

    if is_training:
      output_mask = make_mask(hparams.dis_vd_keep_prob, hparams.dis_rnn_size)

    with tf.variable_scope('rnn') as vs:
      values = []

      rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)

      for t in xrange(FLAGS.sequence_length):
        if t > 0:
          tf.get_variable_scope().reuse_variables()

        if t == 0:
          rnn_in = tf.zeros_like(rnn_inputs[:, 0])
        else:
          rnn_in = rnn_inputs[:, t - 1]
        rnn_out, state_dis = cell_critic(rnn_in, state_dis, scope=dis_scope)

        if is_training:
          rnn_out *= output_mask

        # Prediction is linear output for Discriminator.
        value = tf.contrib.layers.linear(rnn_out, 1, scope=vs)

        values.append(value)
    values = tf.stack(values, axis=1)
  return tf.squeeze(values, axis=2)
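As the docstring says, the critic estimates V(s_t) for a state s_t = x_0, ..., x_{t-1}, so the RNN input at step t is the embedding of the previous token (and a zero vector at t = 0), never the token being valued. A plain-Python sketch of that one-step shift, not part of the commit (the token names are made up):

sequence = ['x0', 'x1', 'x2', 'x3']            # hypothetical tokens for one example

critic_inputs = []
for t in range(len(sequence)):
  if t == 0:
    critic_inputs.append('<zeros>')            # tf.zeros_like(rnn_inputs[:, 0]) in the graph
  else:
    critic_inputs.append(sequence[t - 1])      # rnn_inputs[:, t - 1]

print(critic_inputs)  # ['<zeros>', 'x0', 'x1', 'x2'] -- the value at step t never sees x_t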
research/maskgan/models/evaluation_utils.py  (new file, mode 100644)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluation utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import Counter

# Dependency imports
import numpy as np
from scipy.special import expit
import tensorflow as tf

from model_utils import helper
from model_utils import n_gram

FLAGS = tf.app.flags.FLAGS


def print_and_log_losses(log, step, is_present_rate, avg_dis_loss,
                         avg_gen_loss):
  """Prints and logs losses to the log file.

  Args:
    log: GFile for logs.
    step: Global step.
    is_present_rate: Current masking rate.
    avg_dis_loss: List of Discriminator losses.
    avg_gen_loss: List of Generator losses.
  """
  print('global_step: %d' % step)
  print(' is_present_rate: %.3f' % is_present_rate)
  print(' D train loss: %.5f' % np.mean(avg_dis_loss))
  print(' G train loss: %.5f' % np.mean(avg_gen_loss))
  log.write('\nglobal_step: %d\n' % step)
  log.write((' is_present_rate: %.3f\n' % is_present_rate))
  log.write(' D train loss: %.5f\n' % np.mean(avg_dis_loss))
  log.write(' G train loss: %.5f\n' % np.mean(avg_gen_loss))


def print_and_log(log, id_to_word, sequence_eval, max_num_to_print=5):
  """Helper function for printing and logging evaluated sequences."""
  indices_arr = np.asarray(sequence_eval)
  samples = helper.convert_to_human_readable(id_to_word, indices_arr,
                                             max_num_to_print)

  for i, sample in enumerate(samples):
    print('Sample', i, '. ', sample)
    log.write('\nSample ' + str(i) + '. ' + sample)
  log.write('\n')
  print('\n')
  log.flush()
  return samples


def zip_seq_pred_crossent(id_to_word, sequences, predictions, cross_entropy):
  """Zip together the sequences, predictions, cross entropy."""
  indices = np.asarray(sequences)

  batch_of_metrics = []

  for ind_batch, pred_batch, crossent_batch in zip(indices, predictions,
                                                   cross_entropy):
    metrics = []

    for index, pred, crossent in zip(ind_batch, pred_batch, crossent_batch):
      metrics.append([str(id_to_word[index]), pred, crossent])

    batch_of_metrics.append(metrics)
  return batch_of_metrics


def zip_metrics(indices, *args):
  """Zip together the indices matrices with the provided metrics matrices."""
  batch_of_metrics = []
  for metrics_batch in zip(indices, *args):
    metrics = []

    for m in zip(*metrics_batch):
      metrics.append(m)

    batch_of_metrics.append(metrics)
  return batch_of_metrics


def print_formatted(present, id_to_word, log, batch_of_tuples):
  """Print and log metrics."""
  num_cols = len(batch_of_tuples[0][0])
  repeat_float_format = '{:<12.3f} '
  repeat_str_format = '{:<13}'

  format_str = ''.join(
      ['[{:<1}] {:<20}',
       str(repeat_float_format * (num_cols - 1))])

  # TODO(liamfedus): Generalize the logging.  This is sloppy.
  header_format_str = ''.join(
      ['[{:<1}] {:<20}',
       str(repeat_str_format * (num_cols - 1))])
  header_str = header_format_str.format('p', 'Word', 'p(real)', 'log-perp',
                                        'log(p(a))', 'r', 'R=V*(s)', 'b=V(s)',
                                        'A(a,s)')

  for i, batch in enumerate(batch_of_tuples):
    print(' Sample: %d' % i)
    log.write(' Sample %d.\n' % i)
    print(' ', header_str)
    log.write(' ' + str(header_str) + '\n')

    for j, t in enumerate(batch):
      t = list(t)
      t[0] = id_to_word[t[0]]
      buffer_str = format_str.format(int(present[i][j]), *t)
      print(' ', buffer_str)
      log.write(' ' + str(buffer_str) + '\n')
  log.flush()


def generate_RL_logs(sess, model, log, id_to_word, feed):
  """Generate complete logs while running with REINFORCE."""
  # Impute Sequences.
  [
      p,
      fake_sequence_eval,
      fake_predictions_eval,
      _,
      fake_cross_entropy_losses_eval,
      _,
      fake_log_probs_eval,
      fake_rewards_eval,
      fake_baselines_eval,
      cumulative_rewards_eval,
      fake_advantages_eval,
  ] = sess.run(
      [
          model.present,
          model.fake_sequence,
          model.fake_predictions,
          model.real_predictions,
          model.fake_cross_entropy_losses,
          model.fake_logits,
          model.fake_log_probs,
          model.fake_rewards,
          model.fake_baselines,
          model.cumulative_rewards,
          model.fake_advantages,
      ],
      feed_dict=feed)

  indices = np.asarray(fake_sequence_eval)

  # Convert Discriminator linear layer to probability.
  fake_prob_eval = expit(fake_predictions_eval)

  # Add metrics.
  fake_tuples = zip_metrics(indices, fake_prob_eval,
                            fake_cross_entropy_losses_eval,
                            fake_log_probs_eval, fake_rewards_eval,
                            cumulative_rewards_eval, fake_baselines_eval,
                            fake_advantages_eval)
  # real_tuples = zip_metrics(indices, )

  # Print forward sequences.
  tuples_to_print = fake_tuples[:FLAGS.max_num_to_print]
  print_formatted(p, id_to_word, log, tuples_to_print)

  print('Samples')
  log.write('Samples\n')
  samples = print_and_log(log, id_to_word, fake_sequence_eval,
                          FLAGS.max_num_to_print)
  return samples


def generate_logs(sess, model, log, id_to_word, feed):
  """Impute Sequences using the model for a particular feed and send it to
  logs."""
  # Impute Sequences.
  [p, sequence_eval, fake_predictions_eval, fake_cross_entropy_losses_eval,
   fake_logits_eval] = sess.run(
       [
           model.present, model.fake_sequence, model.fake_predictions,
           model.fake_cross_entropy_losses, model.fake_logits
       ],
       feed_dict=feed)

  # Convert Discriminator linear layer to probability.
  fake_prob_eval = expit(fake_predictions_eval)

  # Forward Masked Tuples.
  fake_tuples = zip_seq_pred_crossent(id_to_word, sequence_eval, fake_prob_eval,
                                      fake_cross_entropy_losses_eval)

  tuples_to_print = fake_tuples[:FLAGS.max_num_to_print]

  if FLAGS.print_verbose:
    print('fake_logits_eval')
    print(fake_logits_eval)

  for i, batch in enumerate(tuples_to_print):
    print(' Sample %d.' % i)
    log.write(' Sample %d.\n' % i)
    for j, pred in enumerate(batch):
      buffer_str = ('[{:<1}] {:<20} {:<7.3f} {:<7.3f}').format(
          int(p[i][j]), pred[0], pred[1], pred[2])
      print(' ', buffer_str)
      log.write(' ' + str(buffer_str) + '\n')
  log.flush()

  print('Samples')
  log.write('Samples\n')
  samples = print_and_log(log, id_to_word, sequence_eval,
                          FLAGS.max_num_to_print)
  return samples


def create_merged_ngram_dictionaries(indices, n):
  """Generate a single dictionary for the full batch.

  Args:
    indices: List of lists of indices.
    n: Degree of n-grams.

  Returns:
    Dictionary of hashed(n-gram tuples) to counts in the batch of indices.
  """
  ngram_dicts = []

  for ind in indices:
    ngrams = n_gram.find_all_ngrams(ind, n=n)
    ngram_counts = n_gram.construct_ngrams_dict(ngrams)
    ngram_dicts.append(ngram_counts)

  merged_gen_dict = Counter()
  for ngram_dict in ngram_dicts:
    merged_gen_dict += Counter(ngram_dict)
  return merged_gen_dict


def sequence_ngram_evaluation(sess, sequence, log, feed, data_ngram_count, n):
  """Calculates the percent of ngrams produced in the sequence is present in
  data_ngram_count.

  Args:
    sess: tf.Session.
    sequence: Sequence Tensor from the MaskGAN model.
    log: gFile log.
    feed: Feed to evaluate.
    data_ngram_count: Dictionary of hashed(n-gram tuples) to counts in the
      data_set.

  Returns:
    avg_percent_captured: Percent of produced ngrams that appear in the
      data_ngram_count.
  """
  del log
  # Impute sequence.
  [sequence_eval] = sess.run([sequence], feed_dict=feed)
  indices = sequence_eval

  # Retrieve the counts across the batch of indices.
  gen_ngram_counts = create_merged_ngram_dictionaries(indices, n=n)
  return n_gram.percent_unique_ngrams_in_train(data_ngram_count,
                                               gen_ngram_counts)
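create_merged_ngram_dictionaries above pools per-sequence n-gram counts by adding collections.Counter objects together. The n_gram helpers are internal to this repository, so the sketch below, which is not part of the commit, substitutes a plain tuple-window extraction just to show the merging behaviour:

from collections import Counter

def find_all_ngrams(tokens, n):
  # Stand-in for n_gram.find_all_ngrams: contiguous n-token windows.
  return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]

batch = [[1, 2, 3, 2, 3], [2, 3, 4]]
merged = Counter()
for ind in batch:
  merged += Counter(find_all_ngrams(ind, n=2))

print(merged)  # Counter({(2, 3): 3, (1, 2): 1, (3, 2): 1, (3, 4): 1})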
research/maskgan/models/feedforward.py  (new file, mode 100644)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple FNN model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS


def discriminator(hparams, sequence, is_training, reuse=None):
  """Define the Discriminator graph."""
  del is_training
  sequence = tf.cast(sequence, tf.int32)

  if FLAGS.dis_share_embedding:
    assert hparams.dis_rnn_size == hparams.gen_rnn_size, (
        "If you wish to share Discriminator/Generator embeddings, they must be"
        " same dimension.")
    with tf.variable_scope("gen/rnn", reuse=True):
      embedding = tf.get_variable("embedding",
                                  [FLAGS.vocab_size, hparams.gen_rnn_size])

  with tf.variable_scope("dis", reuse=reuse):
    if not FLAGS.dis_share_embedding:
      embedding = tf.get_variable("embedding",
                                  [FLAGS.vocab_size, hparams.dis_rnn_size])

    embeddings = tf.nn.embedding_lookup(embedding, sequence)

    # Input matrices.
    W = tf.get_variable(
        "W",
        initializer=tf.truncated_normal(
            shape=[3 * hparams.dis_embedding_dim, hparams.dis_hidden_dim],
            stddev=0.1))
    b = tf.get_variable(
        "b", initializer=tf.constant(0.1, shape=[hparams.dis_hidden_dim]))

    # Output matrices.
    W_out = tf.get_variable(
        "W_out",
        initializer=tf.truncated_normal(
            shape=[hparams.dis_hidden_dim, 1], stddev=0.1))
    b_out = tf.get_variable("b_out", initializer=tf.constant(0.1, shape=[1]))

    predictions = []
    for t in xrange(FLAGS.sequence_length):
      if t > 0:
        tf.get_variable_scope().reuse_variables()

      inp = embeddings[:, t]

      if t > 0:
        past_inp = tf.unstack(embeddings[:, 0:t], axis=1)
        avg_past_inp = tf.add_n(past_inp) / len(past_inp)
      else:
        avg_past_inp = tf.zeros_like(inp)

      if t < FLAGS.sequence_length:
        future_inp = tf.unstack(embeddings[:, t:], axis=1)
        avg_future_inp = tf.add_n(future_inp) / len(future_inp)
      else:
        avg_future_inp = tf.zeros_like(inp)

      # Cumulative input.
      concat_inp = tf.concat([avg_past_inp, inp, avg_future_inp], axis=1)

      # Hidden activations.
      hidden = tf.nn.relu(tf.nn.xw_plus_b(concat_inp, W, b, name="scores"))

      # Add dropout
      with tf.variable_scope("dropout"):
        hidden = tf.nn.dropout(hidden, FLAGS.keep_prob)

      # Output.
      output = tf.nn.xw_plus_b(hidden, W_out, b_out, name="output")
      predictions.append(output)
  predictions = tf.stack(predictions, axis=1)
  return tf.squeeze(predictions, axis=2)
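At every position the feedforward discriminator above scores a token from a fixed-width context: the mean of the past embeddings, the current embedding, and the mean of the embeddings from the current position onward, concatenated into a 3 * embedding_dim vector. A NumPy sketch of that context construction for a single time step, not part of the commit (shapes are illustrative, and NumPy stands in for the TF ops):

import numpy as np

batch, seq_len, emb_dim = 2, 5, 4                       # made-up sizes
embeddings = np.random.randn(batch, seq_len, emb_dim)

t = 2
inp = embeddings[:, t]
avg_past = embeddings[:, :t].mean(axis=1) if t > 0 else np.zeros_like(inp)
avg_future = embeddings[:, t:].mean(axis=1)             # includes position t, as in the graph
concat_inp = np.concatenate([avg_past, inp, avg_future], axis=1)

print(concat_inp.shape)  # (2, 12) == (batch, 3 * emb_dim)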
research/maskgan/models/rnn.py  (new file, mode 100644)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple RNN model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

# ZoneoutWrapper.
from regularization import zoneout

FLAGS = tf.app.flags.FLAGS


def generator(hparams,
              inputs,
              targets,
              targets_present,
              is_training,
              is_validating,
              reuse=None):
  """Define the Generator graph.

  G will now impute tokens that have been masked from the input seqeunce.
  """
  tf.logging.warning(
      'Undirectional generative model is not a useful model for this MaskGAN '
      'because future context is needed.  Use only for debugging purposes.')
  init_scale = 0.05
  initializer = tf.random_uniform_initializer(-init_scale, init_scale)

  with tf.variable_scope('gen', reuse=reuse, initializer=initializer):

    def lstm_cell():
      return tf.contrib.rnn.LayerNormBasicLSTMCell(
          hparams.gen_rnn_size, reuse=reuse)

    attn_cell = lstm_cell
    if FLAGS.zoneout_drop_prob > 0.0:

      def attn_cell():
        return zoneout.ZoneoutWrapper(
            lstm_cell(),
            zoneout_drop_prob=FLAGS.zoneout_drop_prob,
            is_training=is_training)

    cell_gen = tf.contrib.rnn.MultiRNNCell(
        [attn_cell() for _ in range(hparams.gen_num_layers)],
        state_is_tuple=True)

    initial_state = cell_gen.zero_state(FLAGS.batch_size, tf.float32)

    with tf.variable_scope('rnn'):
      sequence, logits, log_probs = [], [], []

      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, hparams.gen_rnn_size])
      softmax_w = tf.get_variable('softmax_w',
                                  [hparams.gen_rnn_size, FLAGS.vocab_size])
      softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size])

      rnn_inputs = tf.nn.embedding_lookup(embedding, inputs)

      for t in xrange(FLAGS.sequence_length):
        if t > 0:
          tf.get_variable_scope().reuse_variables()

        # Input to the model is the first token to provide context.  The
        # model will then predict token t > 0.
        if t == 0:
          # Always provide the real input at t = 0.
          state_gen = initial_state
          rnn_inp = rnn_inputs[:, t]

        # If the target at the last time-step was present, read in the real.
        # If the target at the last time-step was not present, read in the fake.
        else:
          real_rnn_inp = rnn_inputs[:, t]
          fake_rnn_inp = tf.nn.embedding_lookup(embedding, fake)

          # Use teacher forcing.
          if (is_training and
              FLAGS.gen_training_strategy == 'cross_entropy') or is_validating:
            rnn_inp = real_rnn_inp
          else:
            # Note that targets_t-1 == inputs_(t)
            rnn_inp = tf.where(targets_present[:, t - 1], real_rnn_inp,
                               fake_rnn_inp)

        # RNN.
        rnn_out, state_gen = cell_gen(rnn_inp, state_gen)
        logit = tf.matmul(rnn_out, softmax_w) + softmax_b

        # Real sample.
        real = targets[:, t]

        # Fake sample.
        categorical = tf.contrib.distributions.Categorical(logits=logit)
        fake = categorical.sample()
        log_prob = categorical.log_prob(fake)

        # Output for Generator will either be generated or the target.
        # If present: Return real.
        # If not present: Return fake.
        output = tf.where(targets_present[:, t], real, fake)

        # Append to lists.
        sequence.append(output)
        logits.append(logit)
        log_probs.append(log_prob)

      # Produce the RNN state had the model operated only
      # over real data.
      real_state_gen = initial_state
      for t in xrange(FLAGS.sequence_length):
        tf.get_variable_scope().reuse_variables()

        rnn_inp = rnn_inputs[:, t]

        # RNN.
        rnn_out, real_state_gen = cell_gen(rnn_inp, real_state_gen)

      final_state = real_state_gen

  return (tf.stack(sequence, axis=1), tf.stack(logits, axis=1),
          tf.stack(log_probs, axis=1), initial_state, final_state)


def discriminator(hparams, sequence, is_training, reuse=None):
  """Define the Discriminator graph.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    FLAGS: Current flags.
    sequence: [FLAGS.batch_size, FLAGS.sequence_length]
    is_training:
    reuse

  Returns:
    predictions:
  """
  tf.logging.warning(
      'Undirectional Discriminative model is not a useful model for this '
      'MaskGAN because future context is needed.  Use only for debugging '
      'purposes.')
  sequence = tf.cast(sequence, tf.int32)

  if FLAGS.dis_share_embedding:
    assert hparams.dis_rnn_size == hparams.gen_rnn_size, (
        'If you wish to share Discriminator/Generator embeddings, they must be'
        ' same dimension.')
    with tf.variable_scope('gen/rnn', reuse=True):
      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, hparams.gen_rnn_size])

  with tf.variable_scope('dis', reuse=reuse):

    def lstm_cell():
      return tf.contrib.rnn.LayerNormBasicLSTMCell(
          hparams.dis_rnn_size, reuse=reuse)

    attn_cell = lstm_cell
    if FLAGS.zoneout_drop_prob > 0.0:

      def attn_cell():
        return zoneout.ZoneoutWrapper(
            lstm_cell(),
            zoneout_drop_prob=FLAGS.zoneout_drop_prob,
            is_training=is_training)

    cell_dis = tf.contrib.rnn.MultiRNNCell(
        [attn_cell() for _ in range(hparams.dis_num_layers)],
        state_is_tuple=True)

    state_dis = cell_dis.zero_state(FLAGS.batch_size, tf.float32)

    with tf.variable_scope('rnn') as vs:
      predictions = []
      if not FLAGS.dis_share_embedding:
        embedding = tf.get_variable('embedding',
                                    [FLAGS.vocab_size, hparams.dis_rnn_size])

      rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)

      for t in xrange(FLAGS.sequence_length):
        if t > 0:
          tf.get_variable_scope().reuse_variables()

        rnn_in = rnn_inputs[:, t]
        rnn_out, state_dis = cell_dis(rnn_in, state_dis)

        # Prediction is linear output for Discriminator.
        pred = tf.contrib.layers.linear(rnn_out, 1, scope=vs)

        predictions.append(pred)
  predictions = tf.stack(predictions, axis=1)
  return tf.squeeze(predictions, axis=2)
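The generator above uses tf.where(targets_present[...], real, fake) both to choose each step's input and to assemble the output sequence: positions marked present keep the ground-truth token, masked positions take the token just sampled from the softmax. The same element-wise selection in NumPy, not part of the commit (the mask and token ids are made up):

import numpy as np

present = np.array([True, False, True, False])    # hypothetical mask for one time step
real = np.array([11, 12, 13, 14])                 # ground-truth token ids
fake = np.array([91, 92, 93, 94])                 # sampled token ids

print(np.where(present, real, fake))  # [11 92 13 94] -- real where present, sampled where masked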
research/maskgan/models/rnn_nas.py  (new file, mode 100644)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple RNN model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import tensorflow as tf

# NAS Code..
from nas_utils import configs
from nas_utils import custom_cell
from nas_utils import variational_dropout

FLAGS = tf.app.flags.FLAGS


def get_config():
  return configs.AlienConfig2()


LSTMTuple = collections.namedtuple('LSTMTuple', ['c', 'h'])


def generator(hparams,
              inputs,
              targets,
              targets_present,
              is_training,
              is_validating,
              reuse=None):
  """Define the Generator graph.

  G will now impute tokens that have been masked from the input seqeunce.
  """
  tf.logging.info(
      'Undirectional generative model is not a useful model for this MaskGAN '
      'because future context is needed.  Use only for debugging purposes.')
  config = get_config()
  config.keep_prob = [hparams.gen_nas_keep_prob_0, hparams.gen_nas_keep_prob_1]
  configs.print_config(config)

  init_scale = config.init_scale
  initializer = tf.random_uniform_initializer(-init_scale, init_scale)

  with tf.variable_scope('gen', reuse=reuse, initializer=initializer):
    # Neural architecture search cell.
    cell = custom_cell.Alien(config.hidden_size)

    if is_training:
      [h2h_masks, _, _,
       output_mask] = variational_dropout.generate_variational_dropout_masks(
           hparams, config.keep_prob)
    else:
      output_mask = None

    cell_gen = custom_cell.GenericMultiRNNCell([cell] * config.num_layers)

    initial_state = cell_gen.zero_state(FLAGS.batch_size, tf.float32)

    with tf.variable_scope('rnn'):
      sequence, logits, log_probs = [], [], []

      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, hparams.gen_rnn_size])
      softmax_w = tf.matrix_transpose(embedding)
      softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size])

      rnn_inputs = tf.nn.embedding_lookup(embedding, inputs)

      if is_training and FLAGS.keep_prob < 1:
        rnn_inputs = tf.nn.dropout(rnn_inputs, FLAGS.keep_prob)

      for t in xrange(FLAGS.sequence_length):
        if t > 0:
          tf.get_variable_scope().reuse_variables()

        # Input to the model is the first token to provide context.  The
        # model will then predict token t > 0.
        if t == 0:
          # Always provide the real input at t = 0.
          state_gen = initial_state
          rnn_inp = rnn_inputs[:, t]

        # If the input is present, read in the input at t.
        # If the input is not present, read in the previously generated.
        else:
          real_rnn_inp = rnn_inputs[:, t]
          fake_rnn_inp = tf.nn.embedding_lookup(embedding, fake)

          # While validating, the decoder should be operating in teacher
          # forcing regime.  Also, if we're just training with cross_entropy
          # use teacher forcing.
          if is_validating or (is_training and
                               FLAGS.gen_training_strategy == 'cross_entropy'):
            rnn_inp = real_rnn_inp
          else:
            rnn_inp = tf.where(targets_present[:, t - 1], real_rnn_inp,
                               fake_rnn_inp)

        if is_training:
          state_gen = list(state_gen)
          for layer_num, per_layer_state in enumerate(state_gen):
            per_layer_state = LSTMTuple(
                per_layer_state[0], per_layer_state[1] * h2h_masks[layer_num])
            state_gen[layer_num] = per_layer_state

        # RNN.
        rnn_out, state_gen = cell_gen(rnn_inp, state_gen)

        if is_training:
          rnn_out = output_mask * rnn_out

        logit = tf.matmul(rnn_out, softmax_w) + softmax_b

        # Real sample.
        real = targets[:, t]

        categorical = tf.contrib.distributions.Categorical(logits=logit)
        fake = categorical.sample()
        log_prob = categorical.log_prob(fake)

        # Output for Generator will either be generated or the input.
        #
        # If present: Return real.
        # If not present: Return fake.
        output = tf.where(targets_present[:, t], real, fake)

        # Add to lists.
        sequence.append(output)
        log_probs.append(log_prob)
        logits.append(logit)

      # Produce the RNN state had the model operated only
      # over real data.
      real_state_gen = initial_state
      for t in xrange(FLAGS.sequence_length):
        tf.get_variable_scope().reuse_variables()

        rnn_inp = rnn_inputs[:, t]

        # RNN.
        rnn_out, real_state_gen = cell_gen(rnn_inp, real_state_gen)

      final_state = real_state_gen

  return (tf.stack(sequence, axis=1), tf.stack(logits, axis=1),
          tf.stack(log_probs, axis=1), initial_state, final_state)


def discriminator(hparams, sequence, is_training, reuse=None):
  """Define the Discriminator graph."""
  tf.logging.info(
      'Undirectional Discriminative model is not a useful model for this '
      'MaskGAN because future context is needed.  Use only for debugging '
      'purposes.')
  sequence = tf.cast(sequence, tf.int32)

  if FLAGS.dis_share_embedding:
    assert hparams.dis_rnn_size == hparams.gen_rnn_size, (
        'If you wish to share Discriminator/Generator embeddings, they must be'
        ' same dimension.')
    with tf.variable_scope('gen/rnn', reuse=True):
      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, hparams.gen_rnn_size])

  config = get_config()
  config.keep_prob = [hparams.dis_nas_keep_prob_0, hparams.dis_nas_keep_prob_1]
  configs.print_config(config)

  with tf.variable_scope('dis', reuse=reuse):
    # Neural architecture search cell.
    cell = custom_cell.Alien(config.hidden_size)

    if is_training:
      [h2h_masks, _, _,
       output_mask] = variational_dropout.generate_variational_dropout_masks(
           hparams, config.keep_prob)
    else:
      output_mask = None

    cell_dis = custom_cell.GenericMultiRNNCell([cell] * config.num_layers)
    state_dis = cell_dis.zero_state(FLAGS.batch_size, tf.float32)

    with tf.variable_scope('rnn') as vs:
      predictions = []
      if not FLAGS.dis_share_embedding:
        embedding = tf.get_variable('embedding',
                                    [FLAGS.vocab_size, hparams.dis_rnn_size])

      rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)

      if is_training and FLAGS.keep_prob < 1:
        rnn_inputs = tf.nn.dropout(rnn_inputs, FLAGS.keep_prob)

      for t in xrange(FLAGS.sequence_length):
        if t > 0:
          tf.get_variable_scope().reuse_variables()

        rnn_in = rnn_inputs[:, t]

        if is_training:
          state_dis = list(state_dis)
          for layer_num, per_layer_state in enumerate(state_dis):
            per_layer_state = LSTMTuple(
                per_layer_state[0], per_layer_state[1] * h2h_masks[layer_num])
            state_dis[layer_num] = per_layer_state

        # RNN.
        rnn_out, state_dis = cell_dis(rnn_in, state_dis)

        if is_training:
          rnn_out = output_mask * rnn_out

        # Prediction is linear output for Discriminator.
        pred = tf.contrib.layers.linear(rnn_out, 1, scope=vs)

        predictions.append(pred)
  predictions = tf.stack(predictions, axis=1)
  return tf.squeeze(predictions, axis=2)
research/maskgan/models/rnn_vd.py  (new file, mode 100644)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple RNN model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from regularization import variational_dropout

FLAGS = tf.app.flags.FLAGS


def discriminator(hparams,
                  sequence,
                  is_training,
                  reuse=None,
                  initial_state=None):
  """Define the Discriminator graph."""
  tf.logging.info(
      'Undirectional Discriminative model is not a useful model for this '
      'MaskGAN because future context is needed.  Use only for debugging '
      'purposes.')
  sequence = tf.cast(sequence, tf.int32)

  if FLAGS.dis_share_embedding:
    assert hparams.dis_rnn_size == hparams.gen_rnn_size, (
        'If you wish to share Discriminator/Generator embeddings, they must be'
        ' same dimension.')
    with tf.variable_scope('gen/decoder/rnn', reuse=True):
      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, hparams.gen_rnn_size])

  with tf.variable_scope('dis', reuse=reuse):

    def lstm_cell():
      return tf.contrib.rnn.BasicLSTMCell(
          hparams.dis_rnn_size,
          forget_bias=0.0,
          state_is_tuple=True,
          reuse=reuse)

    attn_cell = lstm_cell
    if is_training and hparams.dis_vd_keep_prob < 1:

      def attn_cell():
        return variational_dropout.VariationalDropoutWrapper(
            lstm_cell(), FLAGS.batch_size, hparams.dis_rnn_size,
            hparams.dis_vd_keep_prob, hparams.dis_vd_keep_prob)

    cell_dis = tf.contrib.rnn.MultiRNNCell(
        [attn_cell() for _ in range(hparams.dis_num_layers)],
        state_is_tuple=True)

    if initial_state:
      state_dis = [[tf.identity(x) for x in inner_initial_state]
                   for inner_initial_state in initial_state]
    else:
      state_dis = cell_dis.zero_state(FLAGS.batch_size, tf.float32)

    def make_mask(keep_prob, units):
      random_tensor = keep_prob
      # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
      random_tensor += tf.random_uniform(tf.stack([FLAGS.batch_size, units]))
      return tf.floor(random_tensor) / keep_prob

    if is_training:
      output_mask = make_mask(hparams.dis_vd_keep_prob, hparams.dis_rnn_size)

    with tf.variable_scope('rnn') as vs:
      predictions, rnn_outs = [], []

      if not FLAGS.dis_share_embedding:
        embedding = tf.get_variable('embedding',
                                    [FLAGS.vocab_size, hparams.dis_rnn_size])

      rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)

      for t in xrange(FLAGS.sequence_length):
        if t > 0:
          tf.get_variable_scope().reuse_variables()

        rnn_in = rnn_inputs[:, t]
        rnn_out, state_dis = cell_dis(rnn_in, state_dis)

        if is_training:
          rnn_out *= output_mask

        # Prediction is linear output for Discriminator.
        pred = tf.contrib.layers.linear(rnn_out, 1, scope=vs)
        predictions.append(pred)
        rnn_outs.append(rnn_out)

    predictions = tf.stack(predictions, axis=1)

    if FLAGS.baseline_method == 'critic':
      with tf.variable_scope('critic', reuse=reuse) as critic_scope:
        rnn_outs = tf.stack(rnn_outs, axis=1)
        values = tf.contrib.layers.linear(rnn_outs, 1, scope=critic_scope)

      return tf.squeeze(predictions, axis=2), tf.squeeze(values, axis=2)
    else:
      return tf.squeeze(predictions, axis=2), None
research/maskgan/models/rnn_zaremba.py  (new file, mode 100644)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple RNN model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS


def generator(hparams,
              inputs,
              targets,
              targets_present,
              is_training,
              is_validating,
              reuse=None):
  """Define the Generator graph.

  G will now impute tokens that have been masked from the input seqeunce.
  """
  tf.logging.warning(
      'Undirectional generative model is not a useful model for this MaskGAN '
      'because future context is needed.  Use only for debugging purposes.')
  init_scale = 0.05
  initializer = tf.random_uniform_initializer(-init_scale, init_scale)

  with tf.variable_scope('gen', reuse=reuse, initializer=initializer):

    def lstm_cell():
      return tf.contrib.rnn.BasicLSTMCell(
          hparams.gen_rnn_size,
          forget_bias=0.0,
          state_is_tuple=True,
          reuse=reuse)

    attn_cell = lstm_cell
    if is_training and FLAGS.keep_prob < 1:

      def attn_cell():
        return tf.contrib.rnn.DropoutWrapper(
            lstm_cell(), output_keep_prob=FLAGS.keep_prob)

    cell_gen = tf.contrib.rnn.MultiRNNCell(
        [attn_cell() for _ in range(hparams.gen_num_layers)],
        state_is_tuple=True)

    initial_state = cell_gen.zero_state(FLAGS.batch_size, tf.float32)

    with tf.variable_scope('rnn'):
      sequence, logits, log_probs = [], [], []

      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, hparams.gen_rnn_size])
      softmax_w = tf.get_variable('softmax_w',
                                  [hparams.gen_rnn_size, FLAGS.vocab_size])
      softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size])

      rnn_inputs = tf.nn.embedding_lookup(embedding, inputs)

      if is_training and FLAGS.keep_prob < 1:
        rnn_inputs = tf.nn.dropout(rnn_inputs, FLAGS.keep_prob)

      fake = None
      for t in xrange(FLAGS.sequence_length):
        if t > 0:
          tf.get_variable_scope().reuse_variables()

        # Input to the model is the first token to provide context.  The
        # model will then predict token t > 0.
        if t == 0:
          # Always provide the real input at t = 0.
          state_gen = initial_state
          rnn_inp = rnn_inputs[:, t]

        # If the input is present, read in the input at t.
        # If the input is not present, read in the previously generated.
        else:
          real_rnn_inp = rnn_inputs[:, t]
          fake_rnn_inp = tf.nn.embedding_lookup(embedding, fake)

          # While validating, the decoder should be operating in teacher
          # forcing regime.  Also, if we're just training with cross_entropy
          # use teacher forcing.
          if is_validating or (is_training and
                               FLAGS.gen_training_strategy == 'cross_entropy'):
            rnn_inp = real_rnn_inp
          else:
            rnn_inp = tf.where(targets_present[:, t - 1], real_rnn_inp,
                               fake_rnn_inp)

        # RNN.
        rnn_out, state_gen = cell_gen(rnn_inp, state_gen)
        logit = tf.matmul(rnn_out, softmax_w) + softmax_b

        # Real sample.
        real = targets[:, t]

        categorical = tf.contrib.distributions.Categorical(logits=logit)
        fake = categorical.sample()
        log_prob = categorical.log_prob(fake)

        # Output for Generator will either be generated or the input.
        #
        # If present: Return real.
        # If not present: Return fake.
        output = tf.where(targets_present[:, t], real, fake)

        # Add to lists.
        sequence.append(output)
        log_probs.append(log_prob)
        logits.append(logit)

      # Produce the RNN state had the model operated only
      # over real data.
      real_state_gen = initial_state
      for t in xrange(FLAGS.sequence_length):
        tf.get_variable_scope().reuse_variables()

        rnn_inp = rnn_inputs[:, t]

        # RNN.
        rnn_out, real_state_gen = cell_gen(rnn_inp, real_state_gen)

      final_state = real_state_gen

  return (tf.stack(sequence, axis=1), tf.stack(logits, axis=1),
          tf.stack(log_probs, axis=1), initial_state, final_state)


def discriminator(hparams, sequence, is_training, reuse=None):
  """Define the Discriminator graph."""
  tf.logging.warning(
      'Undirectional Discriminative model is not a useful model for this '
      'MaskGAN because future context is needed.  Use only for debugging '
      'purposes.')
  sequence = tf.cast(sequence, tf.int32)

  with tf.variable_scope('dis', reuse=reuse):

    def lstm_cell():
      return tf.contrib.rnn.BasicLSTMCell(
          hparams.dis_rnn_size,
          forget_bias=0.0,
          state_is_tuple=True,
          reuse=reuse)

    attn_cell = lstm_cell
    if is_training and FLAGS.keep_prob < 1:

      def attn_cell():
        return tf.contrib.rnn.DropoutWrapper(
            lstm_cell(), output_keep_prob=FLAGS.keep_prob)

    cell_dis = tf.contrib.rnn.MultiRNNCell(
        [attn_cell() for _ in range(hparams.dis_num_layers)],
        state_is_tuple=True)

    state_dis = cell_dis.zero_state(FLAGS.batch_size, tf.float32)

    with tf.variable_scope('rnn') as vs:
      predictions = []
      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, hparams.dis_rnn_size])

      rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)

      if is_training and FLAGS.keep_prob < 1:
        rnn_inputs = tf.nn.dropout(rnn_inputs, FLAGS.keep_prob)

      for t in xrange(FLAGS.sequence_length):
        if t > 0:
          tf.get_variable_scope().reuse_variables()

        rnn_in = rnn_inputs[:, t]
        rnn_out, state_dis = cell_dis(rnn_in, state_dis)

        # Prediction is linear output for Discriminator.
        pred = tf.contrib.layers.linear(rnn_out, 1, scope=vs)

        predictions.append(pred)
  predictions = tf.stack(predictions, axis=1)
  return tf.squeeze(predictions, axis=2)
research/maskgan/models/rollout.py  (new file, mode 100644)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Rollout RNN model definitions which call rnn_zaremba code."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
collections
import
tensorflow
as
tf
from
losses
import
losses
from
model_utils
import
helper
from
model_utils
import
model_construction
from
model_utils
import
model_losses
from
model_utils
import
model_optimization
FLAGS
=
tf
.
app
.
flags
.
FLAGS
def create_rollout_MaskGAN(hparams, is_training):
  """Create the MaskGAN model.

  Args:
    hparams:  Hyperparameters for the MaskGAN.
    is_training:  Boolean indicating operational mode (train/inference),
      evaluated with a teacher forcing regime.

  Return:
    model:  Namedtuple for specifying the MaskGAN.
  """
  global_step = tf.Variable(0, name='global_step', trainable=False)

  new_learning_rate = tf.placeholder(tf.float32, [], name='new_learning_rate')
  learning_rate = tf.Variable(0.0, name='learning_rate', trainable=False)
  learning_rate_update = tf.assign(learning_rate, new_learning_rate)

  new_rate = tf.placeholder(tf.float32, [], name='new_rate')
  percent_real_var = tf.Variable(0.0, trainable=False)
  percent_real_update = tf.assign(percent_real_var, new_rate)

  ## Placeholders.
  inputs = tf.placeholder(
      tf.int32, shape=[FLAGS.batch_size, FLAGS.sequence_length])
  present = tf.placeholder(
      tf.bool, shape=[FLAGS.batch_size, FLAGS.sequence_length])
  inv_present = tf.placeholder(
      tf.bool, shape=[FLAGS.batch_size, FLAGS.sequence_length])

  ## Rollout Generator.
  fwd_gen_rollouts = rollout_generator(
      hparams, inputs, present, is_training=is_training, is_validating=False)
  inv_gen_rollouts = rollout_generator(
      hparams,
      inputs,
      inv_present,
      is_training=is_training,
      is_validating=False,
      reuse=True)

  ## Rollout Discriminator.
  fwd_dis_rollouts = rollout_discriminator(
      hparams, fwd_gen_rollouts, is_training=is_training)
  inv_dis_rollouts = rollout_discriminator(
      hparams, inv_gen_rollouts, is_training=is_training, reuse=True)

  ## Discriminator Loss.
  [dis_loss, dis_loss_pred, dis_loss_inv_pred] = rollout_discriminator_loss(
      fwd_dis_rollouts, present, inv_dis_rollouts, inv_present)

  ## Average log-perplexity for only missing words.  However, to do this,
  # the logits are still computed using teacher forcing, that is, the ground
  # truth tokens are fed in at each time point to be valid.
  # TODO(liamfedus): Fix the naming convention.
  with tf.variable_scope('gen_rollout'):
    _, fwd_eval_logits, _ = model_construction.create_generator(
        hparams,
        inputs,
        present,
        is_training=False,
        is_validating=True,
        reuse=True)

  avg_log_perplexity = model_losses.calculate_log_perplexity(
      fwd_eval_logits, inputs, present)

  ## Generator Loss.
  # 1.  Cross Entropy losses on missing tokens.
  [fwd_cross_entropy_losses,
   inv_cross_entropy_losses] = rollout_masked_cross_entropy_loss(
       inputs, present, inv_present, fwd_gen_rollouts, inv_gen_rollouts)

  # 2.  GAN losses on missing tokens.
  [fwd_RL_loss,
   fwd_RL_statistics, fwd_averages_op] = rollout_reinforce_objective(
       hparams, fwd_gen_rollouts, fwd_dis_rollouts, present)
  [inv_RL_loss,
   inv_RL_statistics, inv_averages_op] = rollout_reinforce_objective(
       hparams, inv_gen_rollouts, inv_dis_rollouts, inv_present)

  # TODO(liamfedus): Generalize this to use all logs.
  [fwd_sequence, fwd_logits, fwd_log_probs] = fwd_gen_rollouts[-1]
  [inv_sequence, inv_logits, inv_log_probs] = inv_gen_rollouts[-1]

  # TODO(liamfedus): Generalize this to use all logs.
  fwd_predictions = fwd_dis_rollouts[-1]
  inv_predictions = inv_dis_rollouts[-1]

  # TODO(liamfedus): Generalize this to use all logs.
  [fwd_log_probs, fwd_rewards, fwd_advantages,
   fwd_baselines] = fwd_RL_statistics[-1]
  [inv_log_probs, inv_rewards, inv_advantages,
   inv_baselines] = inv_RL_statistics[-1]

  ## Pre-training.
  if FLAGS.gen_pretrain_steps:
    # TODO(liamfedus): Rewrite this.
    fwd_cross_entropy_loss = tf.reduce_mean(fwd_cross_entropy_losses)
    gen_pretrain_op = model_optimization.create_gen_pretrain_op(
        hparams, fwd_cross_entropy_loss, global_step)
  else:
    gen_pretrain_op = tf.no_op('gen_pretrain_no_op')
  if FLAGS.dis_pretrain_steps:
    dis_pretrain_op = model_optimization.create_dis_pretrain_op(
        hparams, dis_loss, global_step)
  else:
    dis_pretrain_op = tf.no_op('dis_pretrain_no_op')

  ## Generator Train Op.
  # 1.  Cross-Entropy.
  if FLAGS.gen_training_strategy == 'cross_entropy':
    gen_loss = tf.reduce_mean(
        fwd_cross_entropy_losses + inv_cross_entropy_losses) / 2.
    [gen_train_op, gen_grads,
     gen_vars] = model_optimization.create_gen_train_op(
         hparams, learning_rate, gen_loss, global_step, mode='MINIMIZE')

  # 2.  GAN (REINFORCE).
  elif FLAGS.gen_training_strategy == 'reinforce':
    gen_loss = (fwd_RL_loss + inv_RL_loss) / 2.
    [gen_train_op, gen_grads,
     gen_vars] = model_optimization.create_reinforce_gen_train_op(
         hparams, learning_rate, gen_loss, fwd_averages_op, inv_averages_op,
         global_step)

  else:
    raise NotImplementedError

  ## Discriminator Train Op.
  dis_train_op, dis_grads, dis_vars = model_optimization.create_dis_train_op(
      hparams, dis_loss, global_step)

  ## Summaries.
  with tf.name_scope('general'):
    tf.summary.scalar('percent_real', percent_real_var)
    tf.summary.scalar('learning_rate', learning_rate)

  with tf.name_scope('generator_losses'):
    tf.summary.scalar('gen_loss', tf.reduce_mean(gen_loss))
    tf.summary.scalar('gen_loss_fwd_cross_entropy',
                      tf.reduce_mean(fwd_cross_entropy_losses))
    tf.summary.scalar('gen_loss_inv_cross_entropy',
                      tf.reduce_mean(inv_cross_entropy_losses))

  with tf.name_scope('REINFORCE'):
    with tf.name_scope('objective'):
      tf.summary.scalar('fwd_RL_loss', tf.reduce_mean(fwd_RL_loss))
      tf.summary.scalar('inv_RL_loss', tf.reduce_mean(inv_RL_loss))

    with tf.name_scope('rewards'):
      helper.variable_summaries(fwd_rewards, 'fwd_rewards')
      helper.variable_summaries(inv_rewards, 'inv_rewards')

    with tf.name_scope('advantages'):
      helper.variable_summaries(fwd_advantages, 'fwd_advantages')
      helper.variable_summaries(inv_advantages, 'inv_advantages')

    with tf.name_scope('baselines'):
      helper.variable_summaries(fwd_baselines, 'fwd_baselines')
      helper.variable_summaries(inv_baselines, 'inv_baselines')

    with tf.name_scope('log_probs'):
      helper.variable_summaries(fwd_log_probs, 'fwd_log_probs')
      helper.variable_summaries(inv_log_probs, 'inv_log_probs')

  with tf.name_scope('discriminator_losses'):
    tf.summary.scalar('dis_loss', dis_loss)
    tf.summary.scalar('dis_loss_fwd_sequence', dis_loss_pred)
    tf.summary.scalar('dis_loss_inv_sequence', dis_loss_inv_pred)

  with tf.name_scope('logits'):
    helper.variable_summaries(fwd_logits, 'fwd_logits')
    helper.variable_summaries(inv_logits, 'inv_logits')

  for v, g in zip(gen_vars, gen_grads):
    helper.variable_summaries(v, v.op.name)
    helper.variable_summaries(g, 'grad/' + v.op.name)

  for v, g in zip(dis_vars, dis_grads):
    helper.variable_summaries(v, v.op.name)
    helper.variable_summaries(g, 'grad/' + v.op.name)

  merge_summaries_op = tf.summary.merge_all()

  # Model saver.
  saver = tf.train.Saver(keep_checkpoint_every_n_hours=1, max_to_keep=5)

  # Named tuple that captures elements of the MaskGAN model.
  Model = collections.namedtuple('Model', [
      'inputs', 'present', 'inv_present', 'percent_real_update', 'new_rate',
      'fwd_sequence', 'fwd_logits', 'fwd_rewards', 'fwd_advantages',
      'fwd_log_probs', 'fwd_predictions', 'fwd_cross_entropy_losses',
      'inv_sequence', 'inv_logits', 'inv_rewards', 'inv_advantages',
      'inv_log_probs', 'inv_predictions', 'inv_cross_entropy_losses',
      'avg_log_perplexity', 'dis_loss', 'gen_loss', 'dis_train_op',
      'gen_train_op', 'gen_pretrain_op', 'dis_pretrain_op',
      'merge_summaries_op', 'global_step', 'new_learning_rate',
      'learning_rate_update', 'saver'
  ])

  model = Model(
      inputs, present, inv_present, percent_real_update, new_rate,
      fwd_sequence, fwd_logits, fwd_rewards, fwd_advantages, fwd_log_probs,
      fwd_predictions, fwd_cross_entropy_losses, inv_sequence, inv_logits,
      inv_rewards, inv_advantages, inv_log_probs, inv_predictions,
      inv_cross_entropy_losses, avg_log_perplexity, dis_loss, gen_loss,
      dis_train_op, gen_train_op, gen_pretrain_op, dis_pretrain_op,
      merge_summaries_op, global_step, new_learning_rate,
      learning_rate_update, saver)
  return model
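

# Illustrative usage sketch, not part of the original MaskGAN file: one way
# the Model namedtuple returned above could be driven inside a tf.Session.
# The data arguments (`batch`, `present_mask`) and this helper's name are
# hypothetical; the attribute names come from the namedtuple itself.
def _example_train_step(sess, model, batch, present_mask):
  """Runs one discriminator update and one generator update on a batch."""
  import numpy as np

  feed = {
      model.inputs: batch,  # [batch_size, sequence_length] int32 token ids.
      model.present: present_mask,  # Forward presence mask (bool).
      model.inv_present: np.logical_not(present_mask),  # Complementary mask.
  }
  # Discriminator step first, then generator step, reusing the same feed.
  sess.run(model.dis_train_op, feed_dict=feed)
  _, step = sess.run([model.gen_train_op, model.global_step], feed_dict=feed)
  return step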
def rollout_generator(hparams,
                      inputs,
                      input_present,
                      is_training,
                      is_validating,
                      reuse=None):
  """Define the Generator graph which does rollouts.

  G will now impute tokens that have been masked from the input sequence.
  """
  rollouts = []

  with tf.variable_scope('gen_rollout'):
    for n in xrange(FLAGS.num_rollouts):
      if n > 0:
        # TODO(liamfedus): Why is it necessary here to manually set reuse?
        reuse = True
        tf.get_variable_scope().reuse_variables()

      [sequence, logits, log_probs] = model_construction.create_generator(
          hparams,
          inputs,
          input_present,
          is_training,
          is_validating,
          reuse=reuse)

      rollouts.append([sequence, logits, log_probs])

  # Length assertion.
  assert len(rollouts) == FLAGS.num_rollouts

  return rollouts
def rollout_discriminator(hparams, gen_rollouts, is_training, reuse=None):
  """Define the Discriminator graph which does rollouts.

  The Discriminator scores each of the Generator's rollout sequences.
  """
  rollout_predictions = []

  with tf.variable_scope('dis_rollout'):
    for n, rollout in enumerate(gen_rollouts):
      if n > 0:
        # TODO(liamfedus): Why is it necessary here to manually set reuse?
        reuse = True
        tf.get_variable_scope().reuse_variables()

      [sequence, _, _] = rollout

      predictions = model_construction.create_discriminator(
          hparams, sequence, is_training=is_training, reuse=reuse)

      # Predictions for each rollout.
      rollout_predictions.append(predictions)

  # Length assertion.
  assert len(rollout_predictions) == FLAGS.num_rollouts

  return rollout_predictions
def rollout_reinforce_objective(hparams, gen_rollouts, dis_rollouts, present):
  cumulative_gen_objective = 0.
  cumulative_averages_op = []
  cumulative_statistics = []

  assert len(gen_rollouts) == len(dis_rollouts)

  for gen_rollout, dis_rollout in zip(gen_rollouts, dis_rollouts):
    [_, _, log_probs] = gen_rollout
    dis_predictions = dis_rollout

    [final_gen_objective, log_probs, rewards, advantages, baselines,
     maintain_averages_op] = model_losses.calculate_reinforce_objective(
         hparams, log_probs, dis_predictions, present)

    # Accumulate results.
    cumulative_gen_objective += final_gen_objective
    cumulative_averages_op.append(maintain_averages_op)
    cumulative_statistics.append([log_probs, rewards, advantages, baselines])

  # Group all the averaging operations.
  cumulative_averages_op = tf.group(*cumulative_averages_op)
  cumulative_gen_objective /= FLAGS.num_rollouts
  [log_probs, rewards, advantages, baselines] = cumulative_statistics[-1]

  # Length assertion.
  assert len(cumulative_statistics) == FLAGS.num_rollouts

  return [
      cumulative_gen_objective, cumulative_statistics, cumulative_averages_op
  ]
def rollout_masked_cross_entropy_loss(inputs, present, inv_present,
                                      fwd_rollouts, inv_rollouts):
  cumulative_fwd_cross_entropy_losses = tf.zeros(
      shape=[FLAGS.batch_size, FLAGS.sequence_length])
  cumulative_inv_cross_entropy_losses = tf.zeros(
      shape=[FLAGS.batch_size, FLAGS.sequence_length])

  for fwd_rollout, inv_rollout in zip(fwd_rollouts, inv_rollouts):
    [_, fwd_logits, _] = fwd_rollout
    [_, inv_logits, _] = inv_rollout

    [fwd_cross_entropy_losses,
     inv_cross_entropy_losses] = model_losses.create_masked_cross_entropy_loss(
         inputs, present, inv_present, fwd_logits, inv_logits)

    cumulative_fwd_cross_entropy_losses = tf.add(
        cumulative_fwd_cross_entropy_losses, fwd_cross_entropy_losses)
    cumulative_inv_cross_entropy_losses = tf.add(
        cumulative_inv_cross_entropy_losses, inv_cross_entropy_losses)

  return [
      cumulative_fwd_cross_entropy_losses, cumulative_inv_cross_entropy_losses
  ]
def rollout_discriminator_loss(fwd_rollouts, present, inv_rollouts,
                               inv_present):
  dis_loss = 0
  dis_loss_pred = 0
  dis_loss_inv_pred = 0

  for fwd_predictions, inv_predictions in zip(fwd_rollouts, inv_rollouts):
    dis_loss_pred += losses.discriminator_loss(fwd_predictions, present)
    dis_loss_inv_pred += losses.discriminator_loss(inv_predictions,
                                                   inv_present)

  dis_loss_pred /= FLAGS.num_rollouts
  dis_loss_inv_pred /= FLAGS.num_rollouts

  dis_loss = (dis_loss_pred + dis_loss_inv_pred) / 2.
  return [dis_loss, dis_loss_pred, dis_loss_inv_pred]
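

# Illustrative sketch, not from the original file: the same averaging as
# rollout_discriminator_loss with plain Python numbers, to make the
# normalization explicit (losses are summed over rollouts, divided by the
# number of rollouts, then the forward/inverse directions are averaged).
def _example_rollout_loss_averaging():
  fwd_losses = [0.70, 0.66, 0.62]  # Hypothetical per-rollout fwd losses.
  inv_losses = [0.74, 0.71, 0.69]  # Hypothetical per-rollout inv losses.

  dis_loss_pred = sum(fwd_losses) / len(fwd_losses)      # 0.66
  dis_loss_inv_pred = sum(inv_losses) / len(inv_losses)  # ~0.713
  dis_loss = (dis_loss_pred + dis_loss_inv_pred) / 2.    # ~0.687
  return dis_loss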
research/maskgan/models/seq2seq.py  (new file, mode 0 → 100644)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple seq2seq model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from models import attention_utils

# ZoneoutWrapper.
from regularization import zoneout

FLAGS = tf.app.flags.FLAGS
def transform_input_with_is_missing_token(inputs, targets_present):
  """Transforms the inputs to have missing tokens when it's masked out.  The
  mask is for the targets, so therefore, to determine if an input at time t is
  masked, we have to check if the target at time t - 1 is masked out.

  e.g.
    inputs = [a, b, c, d]
    targets = [b, c, d, e]
    targets_present = [1, 0, 1, 0]

  then,
    transformed_input = [a, b, <missing>, d]

  Args:
    inputs:  tf.int32 Tensor of shape [batch_size, sequence_length] with tokens
      up to, but not including, vocab_size.
    targets_present:  tf.bool Tensor of shape [batch_size, sequence_length]
      with True representing the presence of the word.

  Returns:
    transformed_input:  tf.int32 Tensor of shape [batch_size, sequence_length]
      which takes on value of inputs when the input is present and takes on
      value=vocab_size to indicate a missing token.
  """
  # To fill in if the input is missing.
  input_missing = tf.constant(
      FLAGS.vocab_size,
      dtype=tf.int32,
      shape=[FLAGS.batch_size, FLAGS.sequence_length])

  # The 0th input will always be present to MaskGAN.
  zeroth_input_present = tf.constant(True, tf.bool, shape=[FLAGS.batch_size, 1])

  # Input present mask.
  inputs_present = tf.concat(
      [zeroth_input_present, targets_present[:, :-1]], axis=1)

  transformed_input = tf.where(inputs_present, inputs, input_missing)
  return transformed_input
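

# Illustrative sketch, not from the original file: a NumPy version of the
# transform on the docstring's toy example, assuming a vocabulary of size 5
# and the (hypothetical) ids a=0, b=1, c=2, d=3.  It reproduces
# transformed_input = [a, b, <missing>, d].
def _example_transform_with_missing_token():
  import numpy as np

  vocab_size = 5
  inputs = np.array([[0, 1, 2, 3]], dtype=np.int32)         # [a, b, c, d]
  targets_present = np.array([[True, False, True, False]])  # Mask on targets.

  # Shift the target mask right by one; position 0 is always present.
  inputs_present = np.concatenate(
      [np.ones((1, 1), dtype=bool), targets_present[:, :-1]], axis=1)
  transformed = np.where(inputs_present, inputs, vocab_size)
  return transformed  # [[0 1 5 3]] -> [a, b, <missing>, d]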
def gen_encoder(hparams, inputs, targets_present, is_training, reuse=None):
  """Define the Encoder graph."""
  # We will use the same variable from the decoder.
  if FLAGS.seq2seq_share_embedding:
    with tf.variable_scope('decoder/rnn'):
      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, hparams.gen_rnn_size])

  with tf.variable_scope('encoder', reuse=reuse):

    def lstm_cell():
      return tf.contrib.rnn.LayerNormBasicLSTMCell(
          hparams.gen_rnn_size, reuse=reuse)

    attn_cell = lstm_cell
    if FLAGS.zoneout_drop_prob > 0.0:

      def attn_cell():
        return zoneout.ZoneoutWrapper(
            lstm_cell(),
            zoneout_drop_prob=FLAGS.zoneout_drop_prob,
            is_training=is_training)

    cell = tf.contrib.rnn.MultiRNNCell(
        [attn_cell() for _ in range(hparams.gen_num_layers)],
        state_is_tuple=True)

    initial_state = cell.zero_state(FLAGS.batch_size, tf.float32)

    # Add a missing token for inputs not present.
    real_inputs = inputs
    masked_inputs = transform_input_with_is_missing_token(
        inputs, targets_present)

    with tf.variable_scope('rnn'):
      hidden_states = []

      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size + 1, hparams.gen_rnn_size])

      real_rnn_inputs = tf.nn.embedding_lookup(embedding, real_inputs)
      masked_rnn_inputs = tf.nn.embedding_lookup(embedding, masked_inputs)

      state = initial_state
      for t in xrange(FLAGS.sequence_length):
        if t > 0:
          tf.get_variable_scope().reuse_variables()
        rnn_inp = masked_rnn_inputs[:, t]
        rnn_out, state = cell(rnn_inp, state)
        hidden_states.append(rnn_out)
      final_masked_state = state
      hidden_states = tf.stack(hidden_states, axis=1)

      # Produce the RNN state had the model operated only
      # over real data.
      real_state = initial_state
      for t in xrange(FLAGS.sequence_length):
        tf.get_variable_scope().reuse_variables()

        # RNN.
        rnn_inp = real_rnn_inputs[:, t]
        rnn_out, real_state = cell(rnn_inp, real_state)
      final_state = real_state

  return (hidden_states, final_masked_state), initial_state, final_state
def gen_decoder(hparams,
                inputs,
                targets,
                targets_present,
                encoding_state,
                is_training,
                is_validating,
                reuse=None):
  """Define the Decoder graph.  The Decoder will now impute tokens that
  have been masked from the input sequence.
  """
  gen_decoder_rnn_size = hparams.gen_rnn_size

  with tf.variable_scope('decoder', reuse=reuse):

    def lstm_cell():
      return tf.contrib.rnn.LayerNormBasicLSTMCell(
          gen_decoder_rnn_size, reuse=reuse)

    attn_cell = lstm_cell
    if FLAGS.zoneout_drop_prob > 0.0:

      def attn_cell():
        return zoneout.ZoneoutWrapper(
            lstm_cell(),
            zoneout_drop_prob=FLAGS.zoneout_drop_prob,
            is_training=is_training)

    cell_gen = tf.contrib.rnn.MultiRNNCell(
        [attn_cell() for _ in range(hparams.gen_num_layers)],
        state_is_tuple=True)

    # Hidden encoder states.
    hidden_vector_encodings = encoding_state[0]

    # Carry forward the final state tuple from the encoder.
    # State tuples.
    state_gen = encoding_state[1]

    if FLAGS.attention_option is not None:
      (attention_keys, attention_values, _,
       attention_construct_fn) = attention_utils.prepare_attention(
           hidden_vector_encodings,
           FLAGS.attention_option,
           num_units=gen_decoder_rnn_size,
           reuse=reuse)

    with tf.variable_scope('rnn'):
      sequence, logits, log_probs = [], [], []

      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, gen_decoder_rnn_size])
      softmax_w = tf.get_variable('softmax_w',
                                  [gen_decoder_rnn_size, FLAGS.vocab_size])
      softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size])

      rnn_inputs = tf.nn.embedding_lookup(embedding, inputs)

      for t in xrange(FLAGS.sequence_length):
        if t > 0:
          tf.get_variable_scope().reuse_variables()

        # Input to the Decoder.
        if t == 0:
          # Always provide the real input at t = 0.
          rnn_inp = rnn_inputs[:, t]

        # If the input is present, read in the input at t.
        # If the input is not present, read in the previously generated.
        else:
          real_rnn_inp = rnn_inputs[:, t]
          fake_rnn_inp = tf.nn.embedding_lookup(embedding, fake)

          # While validating, the decoder should be operating in teacher
          # forcing regime.  Also, if we're just training with cross_entropy
          # use teacher forcing.
          if is_validating or (is_training and
                               FLAGS.gen_training_strategy == 'cross_entropy'):
            rnn_inp = real_rnn_inp
          else:
            rnn_inp = tf.where(targets_present[:, t - 1], real_rnn_inp,
                               fake_rnn_inp)

        # RNN.
        rnn_out, state_gen = cell_gen(rnn_inp, state_gen)

        if FLAGS.attention_option is not None:
          rnn_out = attention_construct_fn(rnn_out, attention_keys,
                                           attention_values)
          # # TODO(liamfedus): Assert not "monotonic" attention_type.
          # # TODO(liamfedus): FLAGS.attention_type.
          # context_state = revised_attention_utils._empty_state()
          # rnn_out, context_state = attention_construct_fn(
          #     rnn_out, attention_keys, attention_values, context_state, t)

        logit = tf.matmul(rnn_out, softmax_w) + softmax_b

        # Output for Decoder.
        # If input is present:   Return real at t+1.
        # If input is not present:  Return fake for t+1.
        real = targets[:, t]

        categorical = tf.contrib.distributions.Categorical(logits=logit)
        fake = categorical.sample()
        log_prob = categorical.log_prob(fake)

        output = tf.where(targets_present[:, t], real, fake)

        # Add to lists.
        sequence.append(output)
        log_probs.append(log_prob)
        logits.append(logit)

  return (tf.stack(sequence, axis=1), tf.stack(logits, axis=1),
          tf.stack(log_probs, axis=1))
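

# Illustrative sketch, not from the original file: the free-running input
# selection above, with NumPy standing in for tf.where.  When the previous
# target was present the decoder reads the real token, otherwise it feeds
# back its own sample (all token ids below are made up).
def _example_scheduled_input_selection():
  import numpy as np

  prev_target_present = np.array([True, False, True])  # One entry per row.
  real_token = np.array([12, 7, 43])      # Ground-truth ids at time t.
  sampled_token = np.array([98, 5, 61])   # Ids sampled by the decoder at t-1.
  rnn_input_ids = np.where(prev_target_present, real_token, sampled_token)
  return rnn_input_ids  # [12  5 43]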
def generator(hparams,
              inputs,
              targets,
              targets_present,
              is_training,
              is_validating,
              reuse=None):
  """Define the Generator graph."""
  with tf.variable_scope('gen', reuse=reuse):
    encoder_states, initial_state, final_state = gen_encoder(
        hparams, inputs, targets_present, is_training=is_training, reuse=reuse)
    stacked_sequence, stacked_logits, stacked_log_probs = gen_decoder(
        hparams,
        inputs,
        targets,
        targets_present,
        encoder_states,
        is_training=is_training,
        is_validating=is_validating,
        reuse=reuse)
    return (stacked_sequence, stacked_logits, stacked_log_probs,
            initial_state, final_state)
research/maskgan/models/seq2seq_nas.py  (new file, mode 0 → 100644)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple seq2seq model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import tensorflow as tf

from models import attention_utils

# NAS Code..
from nas_utils import configs
from nas_utils import custom_cell
from nas_utils import variational_dropout

FLAGS = tf.app.flags.FLAGS


def get_config():
  return configs.AlienConfig2()


LSTMTuple = collections.namedtuple('LSTMTuple', ['c', 'h'])
def transform_input_with_is_missing_token(inputs, targets_present):
  """Transforms the inputs to have missing tokens when it's masked out.  The
  mask is for the targets, so therefore, to determine if an input at time t is
  masked, we have to check if the target at time t - 1 is masked out.

  e.g.
    inputs = [a, b, c, d]
    targets = [b, c, d, e]
    targets_present = [1, 0, 1, 0]

  then,
    transformed_input = [a, b, <missing>, d]

  Args:
    inputs:  tf.int32 Tensor of shape [batch_size, sequence_length] with tokens
      up to, but not including, vocab_size.
    targets_present:  tf.bool Tensor of shape [batch_size, sequence_length]
      with True representing the presence of the word.

  Returns:
    transformed_input:  tf.int32 Tensor of shape [batch_size, sequence_length]
      which takes on value of inputs when the input is present and takes on
      value=vocab_size to indicate a missing token.
  """
  # To fill in if the input is missing.
  input_missing = tf.constant(
      FLAGS.vocab_size,
      dtype=tf.int32,
      shape=[FLAGS.batch_size, FLAGS.sequence_length])

  # The 0th input will always be present to MaskGAN.
  zeroth_input_present = tf.constant(True, tf.bool, shape=[FLAGS.batch_size, 1])

  # Input present mask.
  inputs_present = tf.concat(
      [zeroth_input_present, targets_present[:, :-1]], axis=1)

  transformed_input = tf.where(inputs_present, inputs, input_missing)
  return transformed_input
def gen_encoder(hparams, inputs, targets_present, is_training, reuse=None):
  """Define the Encoder graph.

  Args:
    hparams:  Hyperparameters for the MaskGAN.
    inputs:  tf.int32 Tensor of shape [batch_size, sequence_length] with tokens
      up to, but not including, vocab_size.
    targets_present:  tf.bool Tensor of shape [batch_size, sequence_length]
      with True representing the presence of the target.
    is_training:  Boolean indicating operational mode (train/inference).
    reuse (Optional):  Whether to reuse the variables.

  Returns:
    Tuple of (hidden_states, final_state).
  """
  config = get_config()
  configs.print_config(config)
  # We will use the same variable from the decoder.
  if FLAGS.seq2seq_share_embedding:
    with tf.variable_scope('decoder/rnn'):
      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, hparams.gen_rnn_size])

  with tf.variable_scope('encoder', reuse=reuse):
    # Neural architecture search cell.
    cell = custom_cell.Alien(config.hidden_size)

    if is_training:
      [h2h_masks, h2i_masks, _,
       output_mask] = variational_dropout.generate_variational_dropout_masks(
           hparams, config.keep_prob)
    else:
      h2i_masks, output_mask = None, None

    cell = custom_cell.GenericMultiRNNCell([cell] * config.num_layers)

    initial_state = cell.zero_state(FLAGS.batch_size, tf.float32)

    # Add a missing token for inputs not present.
    real_inputs = inputs
    masked_inputs = transform_input_with_is_missing_token(
        inputs, targets_present)

    with tf.variable_scope('rnn'):
      hidden_states = []

      # Split the embedding into two parts so that we can load the PTB
      # weights into one part of the Variable.
      if not FLAGS.seq2seq_share_embedding:
        embedding = tf.get_variable('embedding',
                                    [FLAGS.vocab_size, hparams.gen_rnn_size])
      missing_embedding = tf.get_variable('missing_embedding',
                                          [1, hparams.gen_rnn_size])
      embedding = tf.concat([embedding, missing_embedding], axis=0)

      real_rnn_inputs = tf.nn.embedding_lookup(embedding, real_inputs)
      masked_rnn_inputs = tf.nn.embedding_lookup(embedding, masked_inputs)

      if is_training and FLAGS.keep_prob < 1:
        masked_rnn_inputs = tf.nn.dropout(masked_rnn_inputs, FLAGS.keep_prob)

      state = initial_state
      for t in xrange(FLAGS.sequence_length):
        if t > 0:
          tf.get_variable_scope().reuse_variables()
        rnn_inp = masked_rnn_inputs[:, t]

        if is_training:
          state = list(state)
          for layer_num, per_layer_state in enumerate(state):
            per_layer_state = LSTMTuple(
                per_layer_state[0], per_layer_state[1] * h2h_masks[layer_num])
            state[layer_num] = per_layer_state

        rnn_out, state = cell(rnn_inp, state, h2i_masks)

        if is_training:
          rnn_out = output_mask * rnn_out

        hidden_states.append(rnn_out)
      final_masked_state = state
      hidden_states = tf.stack(hidden_states, axis=1)

      # Produce the RNN state had the model operated only
      # over real data.
      real_state = initial_state
      for t in xrange(FLAGS.sequence_length):
        tf.get_variable_scope().reuse_variables()

        # RNN.
        rnn_inp = real_rnn_inputs[:, t]
        rnn_out, real_state = cell(rnn_inp, real_state)
      final_state = real_state

  return (hidden_states, final_masked_state), initial_state, final_state
def gen_decoder(hparams,
                inputs,
                targets,
                targets_present,
                encoding_state,
                is_training,
                is_validating,
                reuse=None):
  """Define the Decoder graph.  The Decoder will now impute tokens that
  have been masked from the input sequence.
  """
  config = get_config()
  gen_decoder_rnn_size = hparams.gen_rnn_size

  if FLAGS.seq2seq_share_embedding:
    with tf.variable_scope('decoder/rnn', reuse=True):
      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, gen_decoder_rnn_size])

  with tf.variable_scope('decoder', reuse=reuse):
    # Neural architecture search cell.
    cell = custom_cell.Alien(config.hidden_size)

    if is_training:
      [h2h_masks, _, _,
       output_mask] = variational_dropout.generate_variational_dropout_masks(
           hparams, config.keep_prob)
    else:
      output_mask = None

    cell_gen = custom_cell.GenericMultiRNNCell([cell] * config.num_layers)

    # Hidden encoder states.
    hidden_vector_encodings = encoding_state[0]

    # Carry forward the final state tuple from the encoder.
    # State tuples.
    state_gen = encoding_state[1]

    if FLAGS.attention_option is not None:
      (attention_keys, attention_values, _,
       attention_construct_fn) = attention_utils.prepare_attention(
           hidden_vector_encodings,
           FLAGS.attention_option,
           num_units=gen_decoder_rnn_size,
           reuse=reuse)

    with tf.variable_scope('rnn'):
      sequence, logits, log_probs = [], [], []

      if not FLAGS.seq2seq_share_embedding:
        embedding = tf.get_variable('embedding',
                                    [FLAGS.vocab_size, gen_decoder_rnn_size])
      softmax_w = tf.matrix_transpose(embedding)
      softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size])

      rnn_inputs = tf.nn.embedding_lookup(embedding, inputs)

      if is_training and FLAGS.keep_prob < 1:
        rnn_inputs = tf.nn.dropout(rnn_inputs, FLAGS.keep_prob)

      for t in xrange(FLAGS.sequence_length):
        if t > 0:
          tf.get_variable_scope().reuse_variables()

        # Input to the Decoder.
        if t == 0:
          # Always provide the real input at t = 0.
          rnn_inp = rnn_inputs[:, t]

        # If the input is present, read in the input at t.
        # If the input is not present, read in the previously generated.
        else:
          real_rnn_inp = rnn_inputs[:, t]
          fake_rnn_inp = tf.nn.embedding_lookup(embedding, fake)

          # While validating, the decoder should be operating in teacher
          # forcing regime.  Also, if we're just training with cross_entropy
          # use teacher forcing.
          if is_validating or (is_training and
                               FLAGS.gen_training_strategy == 'cross_entropy'):
            rnn_inp = real_rnn_inp
          else:
            rnn_inp = tf.where(targets_present[:, t - 1], real_rnn_inp,
                               fake_rnn_inp)

        if is_training:
          state_gen = list(state_gen)
          for layer_num, per_layer_state in enumerate(state_gen):
            per_layer_state = LSTMTuple(
                per_layer_state[0], per_layer_state[1] * h2h_masks[layer_num])
            state_gen[layer_num] = per_layer_state

        # RNN.
        rnn_out, state_gen = cell_gen(rnn_inp, state_gen)

        if is_training:
          rnn_out = output_mask * rnn_out

        if FLAGS.attention_option is not None:
          rnn_out = attention_construct_fn(rnn_out, attention_keys,
                                           attention_values)
          # # TODO(liamfedus): Assert not "monotonic" attention_type.
          # # TODO(liamfedus): FLAGS.attention_type.
          # context_state = revised_attention_utils._empty_state()
          # rnn_out, context_state = attention_construct_fn(
          #     rnn_out, attention_keys, attention_values, context_state, t)

        logit = tf.matmul(rnn_out, softmax_w) + softmax_b

        # Output for Decoder.
        # If input is present:   Return real at t+1.
        # If input is not present:  Return fake for t+1.
        real = targets[:, t]

        categorical = tf.contrib.distributions.Categorical(logits=logit)
        fake = categorical.sample()
        log_prob = categorical.log_prob(fake)

        output = tf.where(targets_present[:, t], real, fake)

        # Add to lists.
        sequence.append(output)
        log_probs.append(log_prob)
        logits.append(logit)

  return (tf.stack(sequence, axis=1), tf.stack(logits, axis=1),
          tf.stack(log_probs, axis=1))
def generator(hparams,
              inputs,
              targets,
              targets_present,
              is_training,
              is_validating,
              reuse=None):
  """Define the Generator graph."""
  with tf.variable_scope('gen', reuse=reuse):
    encoder_states, initial_state, final_state = gen_encoder(
        hparams, inputs, targets_present, is_training=is_training, reuse=reuse)
    stacked_sequence, stacked_logits, stacked_log_probs = gen_decoder(
        hparams,
        inputs,
        targets,
        targets_present,
        encoder_states,
        is_training=is_training,
        is_validating=is_validating,
        reuse=reuse)
    return (stacked_sequence, stacked_logits, stacked_log_probs,
            initial_state, final_state)
research/maskgan/models/seq2seq_vd.py  (new file, mode 0 → 100644)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple seq2seq model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from models import attention_utils
from regularization import variational_dropout

FLAGS = tf.app.flags.FLAGS
def transform_input_with_is_missing_token(inputs, targets_present):
  """Transforms the inputs to have missing tokens when it's masked out.  The
  mask is for the targets, so therefore, to determine if an input at time t is
  masked, we have to check if the target at time t - 1 is masked out.

  e.g.
    inputs = [a, b, c, d]
    targets = [b, c, d, e]
    targets_present = [1, 0, 1, 0]

  which computes,
    inputs_present = [1, 1, 0, 1]

  and outputs,
    transformed_input = [a, b, <missing>, d]

  Args:
    inputs:  tf.int32 Tensor of shape [batch_size, sequence_length] with tokens
      up to, but not including, vocab_size.
    targets_present:  tf.bool Tensor of shape [batch_size, sequence_length]
      with True representing the presence of the word.

  Returns:
    transformed_input:  tf.int32 Tensor of shape [batch_size, sequence_length]
      which takes on value of inputs when the input is present and takes on
      value=vocab_size to indicate a missing token.
  """
  # To fill in if the input is missing.
  input_missing = tf.constant(
      FLAGS.vocab_size,
      dtype=tf.int32,
      shape=[FLAGS.batch_size, FLAGS.sequence_length])

  # The 0th input will always be present to MaskGAN.
  zeroth_input_present = tf.constant(True, tf.bool, shape=[FLAGS.batch_size, 1])

  # Input present mask.
  inputs_present = tf.concat(
      [zeroth_input_present, targets_present[:, :-1]], axis=1)

  transformed_input = tf.where(inputs_present, inputs, input_missing)
  return transformed_input


# TODO(adai): IMDB labels placeholder to encoder.
def gen_encoder(hparams, inputs, targets_present, is_training, reuse=None):
  """Define the Encoder graph.

  Args:
    hparams:  Hyperparameters for the MaskGAN.
    inputs:  tf.int32 Tensor of shape [batch_size, sequence_length] with tokens
      up to, but not including, vocab_size.
    targets_present:  tf.bool Tensor of shape [batch_size, sequence_length]
      with True representing the presence of the target.
    is_training:  Boolean indicating operational mode (train/inference).
    reuse (Optional):  Whether to reuse the variables.

  Returns:
    Tuple of (hidden_states, final_state).
  """
  # We will use the same variable from the decoder.
  if FLAGS.seq2seq_share_embedding:
    with tf.variable_scope('decoder/rnn'):
      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, hparams.gen_rnn_size])

  with tf.variable_scope('encoder', reuse=reuse):

    def lstm_cell():
      return tf.contrib.rnn.BasicLSTMCell(
          hparams.gen_rnn_size,
          forget_bias=0.0,
          state_is_tuple=True,
          reuse=reuse)

    attn_cell = lstm_cell
    if is_training and hparams.gen_vd_keep_prob < 1:

      def attn_cell():
        return variational_dropout.VariationalDropoutWrapper(
            lstm_cell(), FLAGS.batch_size, hparams.gen_rnn_size,
            hparams.gen_vd_keep_prob, hparams.gen_vd_keep_prob)

    cell = tf.contrib.rnn.MultiRNNCell(
        [attn_cell() for _ in range(hparams.gen_num_layers)],
        state_is_tuple=True)

    initial_state = cell.zero_state(FLAGS.batch_size, tf.float32)

    # Add a missing token for inputs not present.
    real_inputs = inputs
    masked_inputs = transform_input_with_is_missing_token(
        inputs, targets_present)

    with tf.variable_scope('rnn') as scope:
      hidden_states = []

      # Split the embedding into two parts so that we can load the PTB
      # weights into one part of the Variable.
      if not FLAGS.seq2seq_share_embedding:
        embedding = tf.get_variable('embedding',
                                    [FLAGS.vocab_size, hparams.gen_rnn_size])
      missing_embedding = tf.get_variable('missing_embedding',
                                          [1, hparams.gen_rnn_size])
      embedding = tf.concat([embedding, missing_embedding], axis=0)

      # TODO(adai): Perhaps append IMDB labels placeholder to input at
      # each time point.
      real_rnn_inputs = tf.nn.embedding_lookup(embedding, real_inputs)
      masked_rnn_inputs = tf.nn.embedding_lookup(embedding, masked_inputs)

      state = initial_state

      def make_mask(keep_prob, units):
        random_tensor = keep_prob
        # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
        random_tensor += tf.random_uniform(
            tf.stack([FLAGS.batch_size, 1, units]))
        return tf.floor(random_tensor) / keep_prob

      if is_training:
        output_mask = make_mask(hparams.gen_vd_keep_prob, hparams.gen_rnn_size)

      hidden_states, state = tf.nn.dynamic_rnn(
          cell, masked_rnn_inputs, initial_state=state, scope=scope)
      if is_training:
        hidden_states *= output_mask

      final_masked_state = state

      # Produce the RNN state had the model operated only
      # over real data.
      real_state = initial_state
      _, real_state = tf.nn.dynamic_rnn(
          cell, real_rnn_inputs, initial_state=real_state, scope=scope)
      final_state = real_state

  return (hidden_states, final_masked_state), initial_state, final_state
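

# Illustrative sketch, not from the original file: the make_mask trick used
# above, in NumPy.  floor(keep_prob + U[0, 1)) is a Bernoulli(keep_prob) draw,
# and dividing by keep_prob rescales so the expected activation is unchanged.
# Because the mask has a singleton time axis, the same units are dropped at
# every time step, which is what makes the dropout "variational".
def _example_variational_output_mask():
  import numpy as np

  batch_size, units, keep_prob = 2, 4, 0.75
  random_tensor = keep_prob + np.random.uniform(size=(batch_size, 1, units))
  mask = np.floor(random_tensor) / keep_prob     # Entries are 0 or 1/keep_prob.
  activations = np.ones((batch_size, 3, units))  # [batch, time, units].
  return activations * mask  # Identical units zeroed at every time step.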
# TODO(adai): IMDB labels placeholder to encoder.
def gen_encoder_cnn(hparams, inputs, targets_present, is_training, reuse=None):
  """Define the CNN Encoder graph."""
  del reuse
  sequence = transform_input_with_is_missing_token(inputs, targets_present)

  # TODO(liamfedus): Make this a hyperparameter.
  dis_filter_sizes = [3, 4, 5, 6, 7, 8, 9, 10, 15, 20]

  # Keeping track of l2 regularization loss (optional)
  # l2_loss = tf.constant(0.0)

  with tf.variable_scope('encoder', reuse=True):
    with tf.variable_scope('rnn'):
      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, hparams.gen_rnn_size])

  cnn_inputs = tf.nn.embedding_lookup(embedding, sequence)

  # Create a convolution layer for each filter size
  conv_outputs = []
  for filter_size in dis_filter_sizes:
    with tf.variable_scope('conv-%s' % filter_size):
      # Convolution Layer
      filter_shape = [
          filter_size, hparams.gen_rnn_size, hparams.dis_num_filters
      ]
      W = tf.get_variable(
          name='W', initializer=tf.truncated_normal(filter_shape, stddev=0.1))
      b = tf.get_variable(
          name='b',
          initializer=tf.constant(0.1, shape=[hparams.dis_num_filters]))
      conv = tf.nn.conv1d(cnn_inputs, W, stride=1, padding='SAME', name='conv')

      # Apply nonlinearity
      h = tf.nn.relu(tf.nn.bias_add(conv, b), name='relu')

      conv_outputs.append(h)

  # Combine all the pooled features
  dis_num_filters_total = hparams.dis_num_filters * len(dis_filter_sizes)

  h_conv = tf.concat(conv_outputs, axis=2)
  h_conv_flat = tf.reshape(h_conv, [-1, dis_num_filters_total])

  # Add dropout
  if is_training:
    with tf.variable_scope('dropout'):
      h_conv_flat = tf.nn.dropout(h_conv_flat, hparams.gen_vd_keep_prob)

  # Final (unnormalized) scores and predictions
  with tf.variable_scope('output'):
    W = tf.get_variable(
        'W',
        shape=[dis_num_filters_total, hparams.gen_rnn_size],
        initializer=tf.contrib.layers.xavier_initializer())
    b = tf.get_variable(
        name='b', initializer=tf.constant(0.1, shape=[hparams.gen_rnn_size]))
    # l2_loss += tf.nn.l2_loss(W)
    # l2_loss += tf.nn.l2_loss(b)
    predictions = tf.nn.xw_plus_b(h_conv_flat, W, b, name='predictions')
    predictions = tf.reshape(
        predictions,
        shape=[FLAGS.batch_size, FLAGS.sequence_length, hparams.gen_rnn_size])
  final_state = tf.reduce_mean(predictions, 1)
  return predictions, (final_state, final_state)
# TODO(adai): IMDB labels placeholder to decoder.
def gen_decoder(hparams,
                inputs,
                targets,
                targets_present,
                encoding_state,
                is_training,
                is_validating,
                reuse=None):
  """Define the Decoder graph.  The Decoder will now impute tokens that
  have been masked from the input sequence.
  """
  gen_decoder_rnn_size = hparams.gen_rnn_size

  targets = tf.Print(targets, [targets], message='targets', summarize=50)
  if FLAGS.seq2seq_share_embedding:
    with tf.variable_scope('decoder/rnn', reuse=True):
      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, hparams.gen_rnn_size])

  with tf.variable_scope('decoder', reuse=reuse):

    def lstm_cell():
      return tf.contrib.rnn.BasicLSTMCell(
          gen_decoder_rnn_size,
          forget_bias=0.0,
          state_is_tuple=True,
          reuse=reuse)

    attn_cell = lstm_cell
    if is_training and hparams.gen_vd_keep_prob < 1:

      def attn_cell():
        return variational_dropout.VariationalDropoutWrapper(
            lstm_cell(), FLAGS.batch_size, hparams.gen_rnn_size,
            hparams.gen_vd_keep_prob, hparams.gen_vd_keep_prob)

    cell_gen = tf.contrib.rnn.MultiRNNCell(
        [attn_cell() for _ in range(hparams.gen_num_layers)],
        state_is_tuple=True)

    # Hidden encoder states.
    hidden_vector_encodings = encoding_state[0]

    # Carry forward the final state tuple from the encoder.
    # State tuples.
    state_gen = encoding_state[1]

    if FLAGS.attention_option is not None:
      (attention_keys, attention_values, _,
       attention_construct_fn) = attention_utils.prepare_attention(
           hidden_vector_encodings,
           FLAGS.attention_option,
           num_units=gen_decoder_rnn_size,
           reuse=reuse)

    def make_mask(keep_prob, units):
      random_tensor = keep_prob
      # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
      random_tensor += tf.random_uniform(tf.stack([FLAGS.batch_size, units]))
      return tf.floor(random_tensor) / keep_prob

    if is_training:
      output_mask = make_mask(hparams.gen_vd_keep_prob, hparams.gen_rnn_size)

    with tf.variable_scope('rnn'):
      sequence, logits, log_probs = [], [], []

      if not FLAGS.seq2seq_share_embedding:
        embedding = tf.get_variable('embedding',
                                    [FLAGS.vocab_size, hparams.gen_rnn_size])
      softmax_w = tf.matrix_transpose(embedding)
      softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size])

      rnn_inputs = tf.nn.embedding_lookup(embedding, inputs)
      # TODO(adai): Perhaps append IMDB labels placeholder to input at
      # each time point.

      rnn_outs = []

      fake = None
      for t in xrange(FLAGS.sequence_length):
        if t > 0:
          tf.get_variable_scope().reuse_variables()

        # Input to the Decoder.
        if t == 0:
          # Always provide the real input at t = 0.
          rnn_inp = rnn_inputs[:, t]

        # If the input is present, read in the input at t.
        # If the input is not present, read in the previously generated.
        else:
          real_rnn_inp = rnn_inputs[:, t]

          # While validating, the decoder should be operating in teacher
          # forcing regime.  Also, if we're just training with cross_entropy
          # use teacher forcing.
          if is_validating or FLAGS.gen_training_strategy == 'cross_entropy':
            rnn_inp = real_rnn_inp
          else:
            fake_rnn_inp = tf.nn.embedding_lookup(embedding, fake)
            rnn_inp = tf.where(targets_present[:, t - 1], real_rnn_inp,
                               fake_rnn_inp)

        # RNN.
        rnn_out, state_gen = cell_gen(rnn_inp, state_gen)

        if FLAGS.attention_option is not None:
          rnn_out = attention_construct_fn(rnn_out, attention_keys,
                                           attention_values)
        if is_training:
          rnn_out *= output_mask

        rnn_outs.append(rnn_out)
        if FLAGS.gen_training_strategy != 'cross_entropy':
          logit = tf.nn.bias_add(tf.matmul(rnn_out, softmax_w), softmax_b)

          # Output for Decoder.
          # If input is present:   Return real at t+1.
          # If input is not present:  Return fake for t+1.
          real = targets[:, t]

          categorical = tf.contrib.distributions.Categorical(logits=logit)
          if FLAGS.use_gen_mode:
            fake = categorical.mode()
          else:
            fake = categorical.sample()
          log_prob = categorical.log_prob(fake)
          output = tf.where(targets_present[:, t], real, fake)

        else:
          real = targets[:, t]
          logit = tf.zeros(tf.stack([FLAGS.batch_size, FLAGS.vocab_size]))
          log_prob = tf.zeros(tf.stack([FLAGS.batch_size]))
          output = real

        # Add to lists.
        sequence.append(output)
        log_probs.append(log_prob)
        logits.append(logit)

      if FLAGS.gen_training_strategy == 'cross_entropy':
        logits = tf.nn.bias_add(
            tf.matmul(
                tf.reshape(tf.stack(rnn_outs, 1), [-1, gen_decoder_rnn_size]),
                softmax_w), softmax_b)
        logits = tf.reshape(logits,
                            [-1, FLAGS.sequence_length, FLAGS.vocab_size])
      else:
        logits = tf.stack(logits, axis=1)

  return (tf.stack(sequence, axis=1), logits, tf.stack(log_probs, axis=1))
def dis_encoder(hparams, masked_inputs, is_training, reuse=None,
                embedding=None):
  """Define the Discriminator encoder.  Reads in the masked inputs for context
  and produces the hidden states of the encoder."""
  with tf.variable_scope('encoder', reuse=reuse):

    def lstm_cell():
      return tf.contrib.rnn.BasicLSTMCell(
          hparams.dis_rnn_size,
          forget_bias=0.0,
          state_is_tuple=True,
          reuse=reuse)

    attn_cell = lstm_cell
    if is_training and hparams.dis_vd_keep_prob < 1:

      def attn_cell():
        return variational_dropout.VariationalDropoutWrapper(
            lstm_cell(), FLAGS.batch_size, hparams.dis_rnn_size,
            hparams.dis_vd_keep_prob, hparams.dis_vd_keep_prob)

    cell_dis = tf.contrib.rnn.MultiRNNCell(
        [attn_cell() for _ in range(hparams.dis_num_layers)],
        state_is_tuple=True)
    state_dis = cell_dis.zero_state(FLAGS.batch_size, tf.float32)

    with tf.variable_scope('rnn'):
      hidden_states = []

      missing_embedding = tf.get_variable('missing_embedding',
                                          [1, hparams.dis_rnn_size])
      embedding = tf.concat([embedding, missing_embedding], axis=0)
      masked_rnn_inputs = tf.nn.embedding_lookup(embedding, masked_inputs)

      def make_mask(keep_prob, units):
        random_tensor = keep_prob
        # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
        random_tensor += tf.random_uniform(tf.stack([FLAGS.batch_size, units]))
        return tf.floor(random_tensor) / keep_prob

      if is_training:
        output_mask = make_mask(hparams.dis_vd_keep_prob, hparams.dis_rnn_size)

      for t in xrange(FLAGS.sequence_length):
        if t > 0:
          tf.get_variable_scope().reuse_variables()

        rnn_in = masked_rnn_inputs[:, t]
        rnn_out, state_dis = cell_dis(rnn_in, state_dis)
        if is_training:
          rnn_out *= output_mask
        hidden_states.append(rnn_out)
      final_state = state_dis

  return (tf.stack(hidden_states, axis=1), final_state)
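

# Illustrative sketch, not from the original file: the embedding extension
# above gives the <missing> token (id == vocab_size, produced by
# transform_input_with_is_missing_token) its own learned row.  Shapes only,
# with hypothetical sizes.
def _example_missing_embedding_shapes():
  import numpy as np

  vocab_size, dis_rnn_size = 4, 3
  embedding = np.zeros((vocab_size, dis_rnn_size))
  missing_embedding = np.ones((1, dis_rnn_size))
  extended = np.concatenate([embedding, missing_embedding], axis=0)
  assert extended.shape == (vocab_size + 1, dis_rnn_size)
  return extended  # Row index `vocab_size` now embeds the <missing> token.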
def dis_decoder(hparams,
                sequence,
                encoding_state,
                is_training,
                reuse=None,
                embedding=None):
  """Define the Discriminator decoder.  Read in the sequence and predict
  at each time point."""
  sequence = tf.cast(sequence, tf.int32)

  with tf.variable_scope('decoder', reuse=reuse):

    def lstm_cell():
      return tf.contrib.rnn.BasicLSTMCell(
          hparams.dis_rnn_size,
          forget_bias=0.0,
          state_is_tuple=True,
          reuse=reuse)

    attn_cell = lstm_cell
    if is_training and hparams.dis_vd_keep_prob < 1:

      def attn_cell():
        return variational_dropout.VariationalDropoutWrapper(
            lstm_cell(), FLAGS.batch_size, hparams.dis_rnn_size,
            hparams.dis_vd_keep_prob, hparams.dis_vd_keep_prob)

    cell_dis = tf.contrib.rnn.MultiRNNCell(
        [attn_cell() for _ in range(hparams.dis_num_layers)],
        state_is_tuple=True)

    # Hidden encoder states.
    hidden_vector_encodings = encoding_state[0]

    # Carry forward the final state tuple from the encoder.
    # State tuples.
    state = encoding_state[1]

    if FLAGS.attention_option is not None:
      (attention_keys, attention_values, _,
       attention_construct_fn) = attention_utils.prepare_attention(
           hidden_vector_encodings,
           FLAGS.attention_option,
           num_units=hparams.dis_rnn_size,
           reuse=reuse)

    def make_mask(keep_prob, units):
      random_tensor = keep_prob
      # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
      random_tensor += tf.random_uniform(tf.stack([FLAGS.batch_size, units]))
      return tf.floor(random_tensor) / keep_prob

    if is_training:
      output_mask = make_mask(hparams.dis_vd_keep_prob, hparams.dis_rnn_size)

    with tf.variable_scope('rnn') as vs:
      predictions = []

      rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)

      for t in xrange(FLAGS.sequence_length):
        if t > 0:
          tf.get_variable_scope().reuse_variables()

        rnn_in = rnn_inputs[:, t]
        rnn_out, state = cell_dis(rnn_in, state)

        if FLAGS.attention_option is not None:
          rnn_out = attention_construct_fn(rnn_out, attention_keys,
                                           attention_values)
        if is_training:
          rnn_out *= output_mask

        # Prediction is linear output for Discriminator.
        pred = tf.contrib.layers.linear(rnn_out, 1, scope=vs)
        predictions.append(pred)

  predictions = tf.stack(predictions, axis=1)
  return tf.squeeze(predictions, axis=2)
def discriminator(hparams,
                  inputs,
                  targets_present,
                  sequence,
                  is_training,
                  reuse=None):
  """Define the Discriminator graph."""
  if FLAGS.dis_share_embedding:
    assert hparams.dis_rnn_size == hparams.gen_rnn_size, (
        'If you wish to share Discriminator/Generator embeddings, they must be'
        ' same dimension.')
    with tf.variable_scope('gen/decoder/rnn', reuse=True):
      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, hparams.gen_rnn_size])
  else:
    # Explicitly share the embedding.
    with tf.variable_scope('dis/decoder/rnn', reuse=reuse):
      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, hparams.dis_rnn_size])

  # Mask the input sequence.
  masked_inputs = transform_input_with_is_missing_token(
      inputs, targets_present)

  # Confirm masking.
  masked_inputs = tf.Print(
      masked_inputs, [inputs, targets_present, masked_inputs, sequence],
      message='inputs, targets_present, masked_inputs, sequence',
      summarize=10)

  with tf.variable_scope('dis', reuse=reuse):
    encoder_states = dis_encoder(
        hparams,
        masked_inputs,
        is_training=is_training,
        reuse=reuse,
        embedding=embedding)
    predictions = dis_decoder(
        hparams,
        sequence,
        encoder_states,
        is_training=is_training,
        reuse=reuse,
        embedding=embedding)

  # if FLAGS.baseline_method == 'critic':
  #   with tf.variable_scope('critic', reuse=reuse) as critic_scope:
  #     values = tf.contrib.layers.linear(rnn_outs, 1, scope=critic_scope)
  #   values = tf.squeeze(values, axis=2)
  # else:
  #   values = None

  return predictions
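

# Illustrative sketch, not from the original file: the discriminator returns
# one unnormalized logit per time step, shape [batch_size, sequence_length];
# a sigmoid turns each into the probability that the token is real (values
# below are made up).
def _example_predictions_to_probabilities():
  import numpy as np

  predictions = np.array([[2.0, -1.0, 0.5]])       # Hypothetical dis output.
  prob_real = 1.0 / (1.0 + np.exp(-predictions))   # [[0.881 0.269 0.622]]
  return prob_real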
# TODO(adai): IMDB labels placeholder to encoder/decoder.
def generator(hparams,
              inputs,
              targets,
              targets_present,
              is_training,
              is_validating,
              reuse=None):
  """Define the Generator graph."""
  with tf.variable_scope('gen', reuse=reuse):
    encoder_states, initial_state, final_state = gen_encoder(
        hparams, inputs, targets_present, is_training=is_training, reuse=reuse)
    stacked_sequence, stacked_logits, stacked_log_probs = gen_decoder(
        hparams,
        inputs,
        targets,
        targets_present,
        encoder_states,
        is_training=is_training,
        is_validating=is_validating,
        reuse=reuse)
    return (stacked_sequence, stacked_logits, stacked_log_probs,
            initial_state, final_state, encoder_states)
research/maskgan/models/seq2seq_zaremba.py  (new file, mode 0 → 100644)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple seq2seq model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from models import attention_utils

FLAGS = tf.app.flags.FLAGS
def transform_input_with_is_missing_token(inputs, targets_present):
  """Transforms the inputs to have missing tokens when it's masked out.  The
  mask is for the targets, so therefore, to determine if an input at time t is
  masked, we have to check if the target at time t - 1 is masked out.

  e.g.
    inputs = [a, b, c, d]
    targets = [b, c, d, e]
    targets_present = [1, 0, 1, 0]

  then,
    transformed_input = [a, b, <missing>, d]

  Args:
    inputs:  tf.int32 Tensor of shape [batch_size, sequence_length] with tokens
      up to, but not including, vocab_size.
    targets_present:  tf.bool Tensor of shape [batch_size, sequence_length]
      with True representing the presence of the word.

  Returns:
    transformed_input:  tf.int32 Tensor of shape [batch_size, sequence_length]
      which takes on value of inputs when the input is present and takes on
      value=vocab_size to indicate a missing token.
  """
  # To fill in if the input is missing.
  input_missing = tf.constant(
      FLAGS.vocab_size,
      dtype=tf.int32,
      shape=[FLAGS.batch_size, FLAGS.sequence_length])

  # The 0th input will always be present to MaskGAN.
  zeroth_input_present = tf.constant(True, tf.bool, shape=[FLAGS.batch_size, 1])

  # Input present mask.
  inputs_present = tf.concat(
      [zeroth_input_present, targets_present[:, :-1]], axis=1)

  transformed_input = tf.where(inputs_present, inputs, input_missing)
  return transformed_input
def gen_encoder(hparams, inputs, targets_present, is_training, reuse=None):
  """Define the Encoder graph.

  Args:
    hparams:  Hyperparameters for the MaskGAN.
    inputs:  tf.int32 Tensor of shape [batch_size, sequence_length] with tokens
      up to, but not including, vocab_size.
    targets_present:  tf.bool Tensor of shape [batch_size, sequence_length]
      with True representing the presence of the target.
    is_training:  Boolean indicating operational mode (train/inference).
    reuse (Optional):  Whether to reuse the variables.

  Returns:
    Tuple of (hidden_states, final_state).
  """
  with tf.variable_scope('encoder', reuse=reuse):

    def lstm_cell():
      return tf.contrib.rnn.BasicLSTMCell(
          hparams.gen_rnn_size,
          forget_bias=0.0,
          state_is_tuple=True,
          reuse=reuse)

    attn_cell = lstm_cell
    if is_training and FLAGS.keep_prob < 1:

      def attn_cell():
        return tf.contrib.rnn.DropoutWrapper(
            lstm_cell(), output_keep_prob=FLAGS.keep_prob)

    cell = tf.contrib.rnn.MultiRNNCell(
        [attn_cell() for _ in range(hparams.gen_num_layers)],
        state_is_tuple=True)

    initial_state = cell.zero_state(FLAGS.batch_size, tf.float32)

    # Add a missing token for inputs not present.
    real_inputs = inputs
    masked_inputs = transform_input_with_is_missing_token(
        inputs, targets_present)

    with tf.variable_scope('rnn'):
      hidden_states = []

      # Split the embedding into two parts so that we can load the PTB
      # weights into one part of the Variable.
      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, hparams.gen_rnn_size])
      missing_embedding = tf.get_variable('missing_embedding',
                                          [1, hparams.gen_rnn_size])
      embedding = tf.concat([embedding, missing_embedding], axis=0)

      real_rnn_inputs = tf.nn.embedding_lookup(embedding, real_inputs)
      masked_rnn_inputs = tf.nn.embedding_lookup(embedding, masked_inputs)

      if is_training and FLAGS.keep_prob < 1:
        masked_rnn_inputs = tf.nn.dropout(masked_rnn_inputs, FLAGS.keep_prob)

      state = initial_state
      for t in xrange(FLAGS.sequence_length):
        if t > 0:
          tf.get_variable_scope().reuse_variables()
        rnn_inp = masked_rnn_inputs[:, t]
        rnn_out, state = cell(rnn_inp, state)
        hidden_states.append(rnn_out)
      final_masked_state = state
      hidden_states = tf.stack(hidden_states, axis=1)

      # Produce the RNN state had the model operated only
      # over real data.
      real_state = initial_state
      for t in xrange(FLAGS.sequence_length):
        tf.get_variable_scope().reuse_variables()

        # RNN.
        rnn_inp = real_rnn_inputs[:, t]
        rnn_out, real_state = cell(rnn_inp, real_state)
      final_state = real_state

  return (hidden_states, final_masked_state), initial_state, final_state
def gen_decoder(hparams,
                inputs,
                targets,
                targets_present,
                encoding_state,
                is_training,
                is_validating,
                reuse=None):
  """Define the Decoder graph.  The Decoder will now impute tokens that
  have been masked from the input sequence.
  """
  gen_decoder_rnn_size = hparams.gen_rnn_size

  with tf.variable_scope('decoder', reuse=reuse):

    def lstm_cell():
      return tf.contrib.rnn.BasicLSTMCell(
          gen_decoder_rnn_size,
          forget_bias=0.0,
          state_is_tuple=True,
          reuse=reuse)

    attn_cell = lstm_cell
    if is_training and FLAGS.keep_prob < 1:

      def attn_cell():
        return tf.contrib.rnn.DropoutWrapper(
            lstm_cell(), output_keep_prob=FLAGS.keep_prob)

    cell_gen = tf.contrib.rnn.MultiRNNCell(
        [attn_cell() for _ in range(hparams.gen_num_layers)],
        state_is_tuple=True)

    # Hidden encoder states.
    hidden_vector_encodings = encoding_state[0]

    # Carry forward the final state tuple from the encoder.
    # State tuples.
    state_gen = encoding_state[1]

    if FLAGS.attention_option is not None:
      (attention_keys, attention_values, _,
       attention_construct_fn) = attention_utils.prepare_attention(
           hidden_vector_encodings,
           FLAGS.attention_option,
           num_units=gen_decoder_rnn_size,
           reuse=reuse)

    with tf.variable_scope('rnn'):
      sequence, logits, log_probs = [], [], []

      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, hparams.gen_rnn_size])
      softmax_w = tf.matrix_transpose(embedding)
      softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size])

      rnn_inputs = tf.nn.embedding_lookup(embedding, inputs)

      if is_training and FLAGS.keep_prob < 1:
        rnn_inputs = tf.nn.dropout(rnn_inputs, FLAGS.keep_prob)

      rnn_outs = []

      fake = None
      for t in xrange(FLAGS.sequence_length):
        if t > 0:
          tf.get_variable_scope().reuse_variables()

        # Input to the Decoder.
        if t == 0:
          # Always provide the real input at t = 0.
          rnn_inp = rnn_inputs[:, t]

        # If the input is present, read in the input at t.
        # If the input is not present, read in the previously generated.
        else:
          real_rnn_inp = rnn_inputs[:, t]

          # While validating, the decoder should be operating in teacher
          # forcing regime.  Also, if we're just training with cross_entropy
          # use teacher forcing.
          if is_validating or FLAGS.gen_training_strategy == 'cross_entropy':
            rnn_inp = real_rnn_inp
          else:
            fake_rnn_inp = tf.nn.embedding_lookup(embedding, fake)
            rnn_inp = tf.where(targets_present[:, t - 1], real_rnn_inp,
                               fake_rnn_inp)

        # RNN.
        rnn_out, state_gen = cell_gen(rnn_inp, state_gen)

        if FLAGS.attention_option is not None:
          rnn_out = attention_construct_fn(rnn_out, attention_keys,
                                           attention_values)
        rnn_outs.append(rnn_out)
        if FLAGS.gen_training_strategy != 'cross_entropy':
          logit = tf.nn.bias_add(tf.matmul(rnn_out, softmax_w), softmax_b)

          # Output for Decoder.
          # If input is present:   Return real at t+1.
          # If input is not present:  Return fake for t+1.
          real = targets[:, t]

          categorical = tf.contrib.distributions.Categorical(logits=logit)
          fake = categorical.sample()
          log_prob = categorical.log_prob(fake)
          output = tf.where(targets_present[:, t], real, fake)

        else:
          batch_size = tf.shape(rnn_out)[0]
          logit = tf.zeros(tf.stack([batch_size, FLAGS.vocab_size]))
          log_prob = tf.zeros(tf.stack([batch_size]))
          output = targets[:, t]

        # Add to lists.
        sequence.append(output)
        log_probs.append(log_prob)
        logits.append(logit)

      if FLAGS.gen_training_strategy == 'cross_entropy':
        logits = tf.nn.bias_add(
            tf.matmul(
                tf.reshape(tf.stack(rnn_outs, 1), [-1, gen_decoder_rnn_size]),
                softmax_w), softmax_b)
        logits = tf.reshape(logits,
                            [-1, FLAGS.sequence_length, FLAGS.vocab_size])
      else:
        logits = tf.stack(logits, axis=1)

  return (tf.stack(sequence, axis=1), logits, tf.stack(log_probs, axis=1))
def
generator
(
hparams
,
inputs
,
targets
,
targets_present
,
is_training
,
is_validating
,
reuse
=
None
):
"""Define the Generator graph."""
with
tf
.
variable_scope
(
'gen'
,
reuse
=
reuse
):
encoder_states
,
initial_state
,
final_state
=
gen_encoder
(
hparams
,
inputs
,
targets_present
,
is_training
=
is_training
,
reuse
=
reuse
)
stacked_sequence
,
stacked_logits
,
stacked_log_probs
=
gen_decoder
(
hparams
,
inputs
,
targets
,
targets_present
,
encoder_states
,
is_training
=
is_training
,
is_validating
=
is_validating
,
reuse
=
reuse
)
return
(
stacked_sequence
,
stacked_logits
,
stacked_log_probs
,
initial_state
,
final_state
)
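In the decoder above, once targets are masked out the next input comes from the model's own sample: tf.where(targets_present[:, t - 1], real_rnn_inp, fake_rnn_inp) picks, per batch element, the ground-truth embedding where the target was present and the generated one where it was not. Below is a minimal NumPy sketch of that selection logic, operating on token ids rather than embeddings; the variable names are chosen for illustration and are not part of the repo.

import numpy as np

np.random.seed(0)
batch_size, sequence_length, vocab_size = 2, 5, 10

real_tokens = np.random.randint(vocab_size, size=(batch_size, sequence_length))
fake_tokens = np.random.randint(vocab_size, size=(batch_size, sequence_length))
# True where the ground-truth token is shown to the decoder, False where it
# must condition on its own previous sample (the MaskGAN fill-in regime).
targets_present = np.random.rand(batch_size, sequence_length) < 0.5

decoder_inputs = np.where(targets_present, real_tokens, fake_tokens)
print(decoder_inputs)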
research/maskgan/nas_utils/__init__.py
0 → 100644
View file @
7d16fc45
research/maskgan/nas_utils/configs.py
0 → 100644
View file @
7d16fc45
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


def print_config(config):
  print("-" * 10, "Configuration Specs", "-" * 10)
  for item in dir(config):
    if list(item)[0] != "_":
      print(item, getattr(config, item))
  print("-" * 29)


class AlienConfig2(object):
  """Base 8 740 shared embeddings, gets 64.0 (mean: std: min: max: )."""
  init_scale = 0.05
  learning_rate = 1.0
  max_grad_norm = 10
  num_layers = 2
  num_steps = 25
  hidden_size = 740
  max_epoch = 70
  max_max_epoch = 250
  keep_prob = [1 - 0.15, 1 - 0.45]
  lr_decay = 0.95
  batch_size = 20
  vocab_size = 10000
  weight_decay = 1e-4
  share_embeddings = True
  cell = "alien"
  dropout_type = "variational"
research/maskgan/nas_utils/custom_cell.py
0 → 100644
View file @
7d16fc45
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import numpy as np
import tensorflow as tf

flags = tf.flags
FLAGS = tf.app.flags.FLAGS

LSTMTuple = collections.namedtuple('LSTMTuple', ['c', 'h'])


def cell_depth(num):
  num /= 2
  val = np.log2(1 + num)
  assert abs(val - int(val)) == 0
  return int(val)


class GenericMultiRNNCell(tf.contrib.rnn.RNNCell):
  """More generic version of MultiRNNCell that allows you to pass in a dropout mask."""

  def __init__(self, cells):
    """Create a RNN cell composed sequentially of a number of RNNCells.

    Args:
      cells: list of RNNCells that will be composed in this order.
      state_is_tuple: If True, accepted and returned states are n-tuples, where
        `n = len(cells)`.  If False, the states are all concatenated along the
        column axis.  This latter behavior will soon be deprecated.

    Raises:
      ValueError: if cells is empty (not allowed), or at least one of the cells
        returns a state tuple but the flag `state_is_tuple` is `False`.
    """
    self._cells = cells

  @property
  def state_size(self):
    return tuple(cell.state_size for cell in self._cells)

  @property
  def output_size(self):
    return self._cells[-1].output_size

  def __call__(self, inputs, state, input_masks=None, scope=None):
    """Run this multi-layer cell on inputs, starting from state."""
    with tf.variable_scope(scope or type(self).__name__):
      cur_inp = inputs
      new_states = []
      for i, cell in enumerate(self._cells):
        with tf.variable_scope('Cell%d' % i):
          cur_state = state[i]
          if input_masks is not None:
            cur_inp *= input_masks[i]
          cur_inp, new_state = cell(cur_inp, cur_state)
          new_states.append(new_state)
    new_states = tuple(new_states)
    return cur_inp, new_states


class AlienRNNBuilder(tf.contrib.rnn.RNNCell):

  def __init__(self, num_units, params, additional_params, base_size):
    self.num_units = num_units
    self.cell_create_index = additional_params[0]
    self.cell_inject_index = additional_params[1]
    self.base_size = base_size
    # Cell injection parameters are always the last two.
    self.cell_params = params[-2:]
    params = params[:-2]
    self.depth = cell_depth(len(params))
    self.params = params
    # Start with the biggest layer.
    self.units_per_layer = [2**i for i in range(self.depth)][::-1]

  def __call__(self, inputs, state, scope=None):
    with tf.variable_scope(scope or type(self).__name__):
      definition1 = ['add', 'elem_mult', 'max']
      definition2 = [tf.identity, tf.tanh, tf.sigmoid, tf.nn.relu, tf.sin]
      layer_outputs = [[] for _ in range(self.depth)]
      with tf.variable_scope('rnn_builder'):
        curr_index = 0
        c, h = state

        # Run all dense matrix multiplications at once.
        big_h_mat = tf.get_variable(
            'big_h_mat', [self.num_units, self.base_size * self.num_units],
            tf.float32)
        big_inputs_mat = tf.get_variable(
            'big_inputs_mat', [self.num_units, self.base_size * self.num_units],
            tf.float32)
        big_h_output = tf.matmul(h, big_h_mat)
        big_inputs_output = tf.matmul(inputs, big_inputs_mat)
        h_splits = tf.split(big_h_output, self.base_size, axis=1)
        inputs_splits = tf.split(big_inputs_output, self.base_size, axis=1)

        for layer_num, units in enumerate(self.units_per_layer):
          for unit_num in range(units):
            with tf.variable_scope('layer_{}_unit_{}'.format(
                layer_num, unit_num)):
              if layer_num == 0:
                prev1_mat = h_splits[unit_num]
                prev2_mat = inputs_splits[unit_num]
              else:
                prev1_mat = layer_outputs[layer_num - 1][2 * unit_num]
                prev2_mat = layer_outputs[layer_num - 1][2 * unit_num + 1]
              if definition1[self.params[curr_index]] == 'add':
                output = prev1_mat + prev2_mat
              elif definition1[self.params[curr_index]] == 'elem_mult':
                output = prev1_mat * prev2_mat
              elif definition1[self.params[curr_index]] == 'max':
                output = tf.maximum(prev1_mat, prev2_mat)
              if curr_index / 2 == self.cell_create_index:
                # Take the new cell before the activation.
                new_c = tf.identity(output)
              output = definition2[self.params[curr_index + 1]](output)
              if curr_index / 2 == self.cell_inject_index:
                if definition1[self.cell_params[0]] == 'add':
                  output += c
                elif definition1[self.cell_params[0]] == 'elem_mult':
                  output *= c
                elif definition1[self.cell_params[0]] == 'max':
                  output = tf.maximum(output, c)
                output = definition2[self.cell_params[1]](output)
              layer_outputs[layer_num].append(output)
              curr_index += 2
        new_h = layer_outputs[-1][-1]
        return new_h, LSTMTuple(new_c, new_h)

  @property
  def state_size(self):
    return LSTMTuple(self.num_units, self.num_units)

  @property
  def output_size(self):
    return self.num_units


class Alien(AlienRNNBuilder):
  """Base 8 Cell."""

  def __init__(self, num_units):
    params = [
        0, 2, 0, 3, 0, 2, 1, 3, 0, 1, 0, 2, 0, 1, 0, 2, 1, 1, 0, 1, 1, 1, 0,
        2, 1, 0, 0, 1, 1, 1, 0, 1
    ]
    additional_params = [12, 8]
    base_size = 8
    super(Alien, self).__init__(num_units, params, additional_params,
                                base_size)
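A short, hedged usage sketch for the two nas_utils modules above. It assumes the maskgan directory is on PYTHONPATH and a TF 1.x graph-mode environment (the same tf.contrib APIs the files themselves rely on); the shapes and placeholder names are chosen for illustration only.

import tensorflow as tf

from nas_utils import configs
from nas_utils import custom_cell

# Inspect the NAS config that pairs with the "alien" cell.
config = configs.AlienConfig2()
configs.print_config(config)
print(config.keep_prob)  # embedding/hidden keep probabilities (~0.85, ~0.55)

# Build one step of the Alien cell on placeholder inputs.
num_units = 16
cell = custom_cell.Alien(num_units)
print(cell.state_size)   # LSTMTuple(c=16, h=16)

inputs = tf.placeholder(tf.float32, [None, num_units])
state = custom_cell.LSTMTuple(
    tf.placeholder(tf.float32, [None, num_units]),
    tf.placeholder(tf.float32, [None, num_units]))
output, new_state = cell(inputs, state)
print(output.shape)      # (?, 16)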
research/maskgan/nas_utils/variational_dropout.py
0 → 100644
View file @
7d16fc45
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Variational Dropout."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS


def generate_dropout_masks(keep_prob, shape, amount):
  masks = []
  for _ in range(amount):
    dropout_mask = tf.random_uniform(shape) + (keep_prob)
    dropout_mask = tf.floor(dropout_mask) / (keep_prob)
    masks.append(dropout_mask)
  return masks


def generate_variational_dropout_masks(hparams, keep_prob):
  [batch_size, num_steps, size, num_layers] = [
      FLAGS.batch_size, FLAGS.sequence_length, hparams.gen_rnn_size,
      hparams.gen_num_layers
  ]
  if len(keep_prob) == 2:
    emb_keep_prob = keep_prob[0]  # keep prob for embedding matrix
    h2h_keep_prob = emb_keep_prob  # keep prob for hidden to hidden connections
    h2i_keep_prob = keep_prob[1]  # keep prob for hidden to input connections
    out_keep_prob = h2i_keep_prob  # keep probability for output state
  else:
    emb_keep_prob = keep_prob[0]  # keep prob for embedding matrix
    h2h_keep_prob = keep_prob[1]  # keep prob for hidden to hidden connections
    h2i_keep_prob = keep_prob[2]  # keep prob for hidden to input connections
    out_keep_prob = keep_prob[3]  # keep probability for output state
  h2i_masks = []  # Masks for input to recurrent connections
  h2h_masks = []  # Masks for recurrent to recurrent connections

  # Input word dropout mask.
  emb_masks = generate_dropout_masks(emb_keep_prob, [num_steps, 1], batch_size)
  output_mask = generate_dropout_masks(out_keep_prob, [batch_size, size], 1)[0]
  h2i_masks = generate_dropout_masks(h2i_keep_prob, [batch_size, size],
                                     num_layers)
  h2h_masks = generate_dropout_masks(h2h_keep_prob, [batch_size, size],
                                     num_layers)
  return h2h_masks, h2i_masks, emb_masks, output_mask
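generate_dropout_masks builds standard inverted-dropout masks: floor(uniform + keep_prob) is a Bernoulli(keep_prob) draw, and dividing by keep_prob rescales so the expected activation is unchanged. The same construction in NumPy, as a quick illustrative sanity check (not part of the repo):

import numpy as np

np.random.seed(0)
keep_prob = 0.75

# floor(U + keep_prob) is 1 with probability keep_prob, else 0; dividing by
# keep_prob gives the usual inverted-dropout scaling.
mask = np.floor(np.random.uniform(size=(4, 6)) + keep_prob) / keep_prob
print(mask)         # entries are 0.0 or 1/keep_prob
print(mask.mean())  # close to 1.0 on average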
research/maskgan/pretrain_mask_gan.py
0 → 100644
View file @
7d16fc45
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Pretraining functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

import numpy as np
import tensorflow as tf

# Data.
from data import imdb_loader
from data import ptb_loader
from model_utils import model_utils

from models import evaluation_utils

tf.app.flags.DEFINE_integer(
    'gen_pretrain_steps', None,
    'The number of steps to pretrain the generator with cross entropy loss.')
tf.app.flags.DEFINE_integer(
    'dis_pretrain_steps', None,
    'The number of steps to pretrain the discriminator.')

FLAGS = tf.app.flags.FLAGS


def pretrain_generator(sv, sess, model, data, log, id_to_word,
                       data_ngram_counts, is_chief):
  """Pretrain the generator with classic language modeling training."""
  print('\nPretraining generator for %d steps.' % FLAGS.gen_pretrain_steps)
  log.write(
      '\nPretraining generator for %d steps.\n' % FLAGS.gen_pretrain_steps)

  is_pretraining = True

  while is_pretraining:

    costs = 0.
    iters = 0
    if FLAGS.data_set == 'ptb':
      iterator = ptb_loader.ptb_iterator(data, FLAGS.batch_size,
                                         FLAGS.sequence_length,
                                         FLAGS.epoch_size_override)
    elif FLAGS.data_set == 'imdb':
      iterator = imdb_loader.imdb_iterator(data, FLAGS.batch_size,
                                           FLAGS.sequence_length)

    for x, y, _ in iterator:

      # For pretraining with cross entropy loss, we have all tokens in the
      # forward sequence present (all True).
      model_utils.assign_percent_real(sess, model.percent_real_update,
                                      model.new_rate, 1.0)
      p = np.ones(shape=[FLAGS.batch_size, FLAGS.sequence_length], dtype=bool)

      pretrain_feed = {model.inputs: x, model.targets: y, model.present: p}

      [losses, cost_eval, _, step] = sess.run(
          [
              model.fake_cross_entropy_losses, model.avg_log_perplexity,
              model.gen_pretrain_op, model.global_step
          ],
          feed_dict=pretrain_feed)

      costs += cost_eval
      iters += FLAGS.sequence_length

      # Calculate rolling perplexity.
      perplexity = np.exp(costs / iters)

      # Summaries.
      if is_chief and step % FLAGS.summaries_every == 0:
        # Graph summaries.
        summary_str = sess.run(
            model.merge_summaries_op, feed_dict=pretrain_feed)
        sv.SummaryComputed(sess, summary_str)

        # Additional summary.
        for n, data_ngram_count in data_ngram_counts.iteritems():
          avg_percent_captured = evaluation_utils.sequence_ngram_evaluation(
              sess, model.fake_sequence, log, pretrain_feed, data_ngram_count,
              int(n))
          summary_percent_str = tf.Summary(value=[
              tf.Summary.Value(
                  tag='general/%s-grams_percent_correct' % n,
                  simple_value=avg_percent_captured)
          ])
          sv.SummaryComputed(sess, summary_percent_str, global_step=step)

        summary_perplexity_str = tf.Summary(value=[
            tf.Summary.Value(tag='general/perplexity', simple_value=perplexity)
        ])
        sv.SummaryComputed(sess, summary_perplexity_str, global_step=step)

      # Printing and logging.
      if is_chief and step % FLAGS.print_every == 0:
        print('global_step: %d' % step)
        print(' generator loss: %.3f' % np.mean(losses))
        print(' perplexity: %.3f' % perplexity)
        log.write('global_step: %d\n' % step)
        log.write(' generator loss: %.3f\n' % np.mean(losses))
        log.write(' perplexity: %.3f\n' % perplexity)

        for n, data_ngram_count in data_ngram_counts.iteritems():
          avg_percent_captured = evaluation_utils.sequence_ngram_evaluation(
              sess, model.fake_sequence, log, pretrain_feed, data_ngram_count,
              int(n))
          print(' percent of %s-grams captured: %.3f.\n' %
                (n, avg_percent_captured))
          log.write(' percent of %s-grams captured: %.3f.\n\n' %
                    (n, avg_percent_captured))

        evaluation_utils.generate_logs(sess, model, log, id_to_word,
                                       pretrain_feed)

      if step >= FLAGS.gen_pretrain_steps:
        is_pretraining = False
        break
  return


def pretrain_discriminator(sv, sess, model, data, log, id_to_word,
                           data_ngram_counts, is_chief):
  print('\nPretraining discriminator for %d steps.' % FLAGS.dis_pretrain_steps)
  log.write(
      '\nPretraining discriminator for %d steps.\n' % FLAGS.dis_pretrain_steps)

  is_pretraining = True

  while is_pretraining:

    cumulative_costs = 0.
    iters = 0
    if FLAGS.data_set == 'ptb':
      iterator = ptb_loader.ptb_iterator(data, FLAGS.batch_size,
                                         FLAGS.sequence_length,
                                         FLAGS.epoch_size_override)
    elif FLAGS.data_set == 'imdb':
      iterator = imdb_loader.imdb_iterator(data, FLAGS.batch_size,
                                           FLAGS.sequence_length)

    for x, y, _ in iterator:
      is_present_rate = FLAGS.is_present_rate
      # is_present_rate = np.random.uniform(low=0.0, high=1.0)
      model_utils.assign_percent_real(sess, model.percent_real_update,
                                      model.new_rate, is_present_rate)

      # Randomly mask out tokens.
      p = model_utils.generate_mask()

      pretrain_feed = {model.inputs: x, model.targets: y, model.present: p}

      [_, dis_loss_eval, gen_log_perplexity_eval, step] = sess.run(
          [
              model.dis_pretrain_op, model.dis_loss, model.avg_log_perplexity,
              model.global_step
          ],
          feed_dict=pretrain_feed)

      cumulative_costs += gen_log_perplexity_eval
      iters += 1

      # Calculate rolling perplexity.
      perplexity = np.exp(cumulative_costs / iters)

      # Summaries.
      if is_chief and step % FLAGS.summaries_every == 0:
        # Graph summaries.
        summary_str = sess.run(
            model.merge_summaries_op, feed_dict=pretrain_feed)
        sv.SummaryComputed(sess, summary_str)

        # Additional summary.
        for n, data_ngram_count in data_ngram_counts.iteritems():
          avg_percent_captured = evaluation_utils.sequence_ngram_evaluation(
              sess, model.fake_sequence, log, pretrain_feed, data_ngram_count,
              int(n))
          summary_percent_str = tf.Summary(value=[
              tf.Summary.Value(
                  tag='general/%s-grams_percent_correct' % n,
                  simple_value=avg_percent_captured)
          ])
          sv.SummaryComputed(sess, summary_percent_str, global_step=step)

        summary_perplexity_str = tf.Summary(value=[
            tf.Summary.Value(tag='general/perplexity', simple_value=perplexity)
        ])
        sv.SummaryComputed(sess, summary_perplexity_str, global_step=step)

      # Printing and logging.
      if is_chief and step % FLAGS.print_every == 0:
        print('global_step: %d' % step)
        print(' discriminator loss: %.3f' % dis_loss_eval)
        print(' perplexity: %.3f' % perplexity)
        log.write('global_step: %d\n' % step)
        log.write(' discriminator loss: %.3f\n' % dis_loss_eval)
        log.write(' perplexity: %.3f\n' % perplexity)

        for n, data_ngram_count in data_ngram_counts.iteritems():
          avg_percent_captured = evaluation_utils.sequence_ngram_evaluation(
              sess, model.fake_sequence, log, pretrain_feed, data_ngram_count,
              int(n))
          print(' percent of %s-grams captured: %.3f.\n' %
                (n, avg_percent_captured))
          log.write(' percent of %s-grams captured: %.3f.\n\n' %
                    (n, avg_percent_captured))

        evaluation_utils.generate_logs(sess, model, log, id_to_word,
                                       pretrain_feed)

      if step >= FLAGS.dis_pretrain_steps + int(FLAGS.gen_pretrain_steps or 0):
        is_pretraining = False
        break
  return
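The two pretraining loops above differ mainly in the `present` mask they feed: the generator is pretrained with every position present (plain teacher-forced cross-entropy), while the discriminator sees sequences randomly masked at is_present_rate. A NumPy sketch of the two masks follows; the random mask is only an illustrative stand-in for model_utils.generate_mask(), whose exact sampling is defined elsewhere in the repo.

import numpy as np

batch_size, sequence_length = 4, 10
is_present_rate = 0.5

# Generator pretraining: all tokens present -> ordinary language-model loss.
p_gen = np.ones(shape=[batch_size, sequence_length], dtype=bool)

# Discriminator pretraining: roughly is_present_rate of positions are kept,
# the rest must be filled in by the generator (illustrative stand-in only).
p_dis = np.random.rand(batch_size, sequence_length) < is_present_rate

print(p_gen.all(), p_dis.mean())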