ModelZoo / ResNet50_tensorflow · Commits

Commit 87fae3f7, authored Feb 27, 2018 by Andrew M. Dai

Added new MaskGAN model.

parent 813dd09a
Changes: 45 files in total; this page shows 5 changed files with 1383 additions and 0 deletions (+1383, -0).
research/maskgan/regularization/__init__.py (+0, -0)
research/maskgan/regularization/variational_dropout.py (+56, -0)
research/maskgan/regularization/zoneout.py (+64, -0)
research/maskgan/sample_shuffler.py (+95, -0)
research/maskgan/train_mask_gan.py (+1168, -0)
research/maskgan/regularization/__init__.py (new file, mode 100644, empty)
research/maskgan/regularization/variational_dropout.py (new file, mode 100644)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Variational Dropout Wrapper."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
class
VariationalDropoutWrapper
(
tf
.
contrib
.
rnn
.
RNNCell
):
"""Add variational dropout to a RNN cell."""
def
__init__
(
self
,
cell
,
batch_size
,
input_size
,
recurrent_keep_prob
,
input_keep_prob
):
self
.
_cell
=
cell
self
.
_recurrent_keep_prob
=
recurrent_keep_prob
self
.
_input_keep_prob
=
input_keep_prob
def
make_mask
(
keep_prob
,
units
):
random_tensor
=
keep_prob
# 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
random_tensor
+=
tf
.
random_uniform
(
tf
.
stack
([
batch_size
,
units
]))
return
tf
.
floor
(
random_tensor
)
/
keep_prob
self
.
_recurrent_mask
=
make_mask
(
recurrent_keep_prob
,
self
.
_cell
.
state_size
[
0
])
self
.
_input_mask
=
self
.
_recurrent_mask
@
property
def
state_size
(
self
):
return
self
.
_cell
.
state_size
@
property
def
output_size
(
self
):
return
self
.
_cell
.
output_size
def
__call__
(
self
,
inputs
,
state
,
scope
=
None
):
dropped_inputs
=
inputs
*
self
.
_input_mask
dropped_state
=
(
state
[
0
],
state
[
1
]
*
self
.
_recurrent_mask
)
new_h
,
new_state
=
self
.
_cell
(
dropped_inputs
,
dropped_state
,
scope
)
return
new_h
,
new_state
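For orientation, a minimal usage sketch, not part of this commit: because the dropout masks are drawn once in the constructor, every timestep of an unrolled sequence reuses the same mask, which is what distinguishes variational (per-sequence) dropout from ordinary per-step dropout. The cell choice and sizes below are illustrative assumptions (TF 1.x-era API):

# Hypothetical usage sketch (sizes and cell are assumptions, not from the diff).
batch_size, input_size, hidden_size = 20, 650, 650
base_cell = tf.contrib.rnn.BasicLSTMCell(hidden_size)
cell = VariationalDropoutWrapper(
    base_cell, batch_size, input_size,
    recurrent_keep_prob=0.5, input_keep_prob=0.5)
inputs = tf.zeros([batch_size, 35, input_size])  # [batch, time, features]
# The same masks, sampled once at construction, apply at every timestep here.
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)

Note that input_size is accepted but effectively unused: the wrapper assigns the recurrent mask to the input mask, so the two masks are identical.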
research/maskgan/regularization/zoneout.py (new file, mode 100644)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Zoneout Wrapper"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
class
ZoneoutWrapper
(
tf
.
contrib
.
rnn
.
RNNCell
):
"""Add Zoneout to a RNN cell."""
def
__init__
(
self
,
cell
,
zoneout_drop_prob
,
is_training
=
True
):
self
.
_cell
=
cell
self
.
_zoneout_prob
=
zoneout_drop_prob
self
.
_is_training
=
is_training
@
property
def
state_size
(
self
):
return
self
.
_cell
.
state_size
@
property
def
output_size
(
self
):
return
self
.
_cell
.
output_size
def
__call__
(
self
,
inputs
,
state
,
scope
=
None
):
output
,
new_state
=
self
.
_cell
(
inputs
,
state
,
scope
)
if
not
isinstance
(
self
.
_cell
.
state_size
,
tuple
):
new_state
=
tf
.
split
(
value
=
new_state
,
num_or_size_splits
=
2
,
axis
=
1
)
state
=
tf
.
split
(
value
=
state
,
num_or_size_splits
=
2
,
axis
=
1
)
final_new_state
=
[
new_state
[
0
],
new_state
[
1
]]
if
self
.
_is_training
:
for
i
,
state_element
in
enumerate
(
state
):
random_tensor
=
1
-
self
.
_zoneout_prob
# keep probability
random_tensor
+=
tf
.
random_uniform
(
tf
.
shape
(
state_element
))
# 0. if [zoneout_prob, 1.0) and 1. if [1.0, 1.0 + zoneout_prob)
binary_tensor
=
tf
.
floor
(
random_tensor
)
final_new_state
[
i
]
=
(
new_state
[
i
]
-
state_element
)
*
binary_tensor
+
state_element
else
:
for
i
,
state_element
in
enumerate
(
state
):
final_new_state
[
i
]
=
state_element
*
self
.
_zoneout_prob
+
new_state
[
i
]
*
(
1
-
self
.
_zoneout_prob
)
if
isinstance
(
self
.
_cell
.
state_size
,
tuple
):
return
output
,
tf
.
contrib
.
rnn
.
LSTMStateTuple
(
final_new_state
[
0
],
final_new_state
[
1
])
return
output
,
tf
.
concat
([
final_new_state
[
0
],
final_new_state
[
1
]],
1
)
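In the training branch above, binary_tensor is 1 with probability 1 - zoneout_drop_prob, so each state unit is either updated as usual or carried over unchanged from the previous step; at evaluation time the old and new states are instead blended by their expected probabilities, mirroring how dropout rescales at inference. A minimal usage sketch, not part of this commit (the cell and probability are illustrative; 0.1 matches the default of the zoneout_drop_prob flag defined later in train_mask_gan.py):

# Hypothetical usage sketch (TF 1.x-era API).
cell = ZoneoutWrapper(
    tf.contrib.rnn.BasicLSTMCell(650),
    zoneout_drop_prob=0.1,
    is_training=True)  # set False at evaluation to blend states instead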
research/maskgan/sample_shuffler.py (new file, mode 100644)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Shuffle samples for human evaluation.
Local launch command:
python sample_shuffler.py
--input_ml_path=/tmp/ptb/seq2seq_vd_shareemb_forreal_55_3
--input_gan_path=/tmp/ptb/MaskGAN_PTB_ari_avg_56.29_v2.0.0
--output_file_name=/tmp/ptb/shuffled_output.txt
python sample_shuffler.py
--input_ml_path=/tmp/generate_samples/MaskGAN_IMDB_Benchmark_87.1_v0.3.0
--input_gan_path=/tmp/generate_samples/MaskGAN_IMDB_v1.0.1
--output_file_name=/tmp/imdb/shuffled_output.txt
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
# Dependency imports
import
numpy
as
np
import
tensorflow
as
tf
tf
.
app
.
flags
.
DEFINE_string
(
'input_ml_path'
,
'/tmp'
,
'Model output directory.'
)
tf
.
app
.
flags
.
DEFINE_string
(
'input_gan_path'
,
'/tmp'
,
'Model output directory.'
)
tf
.
app
.
flags
.
DEFINE_string
(
'output_file_name'
,
'/tmp/ptb/shuffled_output.txt'
,
'Model output file.'
)
tf
.
app
.
flags
.
DEFINE_boolean
(
'output_masked_logs'
,
False
,
'Whether to display for human evaluation (show masking).'
)
tf
.
app
.
flags
.
DEFINE_integer
(
'number_epochs'
,
1
,
'The number of epochs to produce.'
)
FLAGS
=
tf
.
app
.
flags
.
FLAGS
def
shuffle_samples
(
input_file_1
,
input_file_2
):
"""Shuffle the examples."""
shuffled
=
[]
# Set a random seed to keep fixed mask.
np
.
random
.
seed
(
0
)
for
line_1
,
line_2
in
zip
(
input_file_1
,
input_file_2
):
rand
=
np
.
random
.
randint
(
1
,
3
)
if
rand
==
1
:
shuffled
.
append
((
rand
,
line_1
,
line_2
))
else
:
shuffled
.
append
((
rand
,
line_2
,
line_1
))
input_file_1
.
close
()
input_file_2
.
close
()
return
shuffled
def
generate_output
(
shuffled_tuples
,
output_file_name
):
output_file
=
tf
.
gfile
.
GFile
(
output_file_name
,
mode
=
'w'
)
for
tup
in
shuffled_tuples
:
formatted_tuple
=
(
'
\n
{:<1}, {:<1}, {:<1}'
).
format
(
tup
[
0
],
tup
[
1
].
rstrip
(),
tup
[
2
].
rstrip
())
output_file
.
write
(
formatted_tuple
)
output_file
.
close
()
def
main
(
_
):
ml_samples_file
=
tf
.
gfile
.
GFile
(
os
.
path
.
join
(
FLAGS
.
input_ml_path
,
'reviews.txt'
),
mode
=
'r'
)
gan_samples_file
=
tf
.
gfile
.
GFile
(
os
.
path
.
join
(
FLAGS
.
input_gan_path
,
'reviews.txt'
),
mode
=
'r'
)
# Generate shuffled tuples.
shuffled_tuples
=
shuffle_samples
(
ml_samples_file
,
gan_samples_file
)
# Output to file.
generate_output
(
shuffled_tuples
,
FLAGS
.
output_file_name
)
if
__name__
==
'__main__'
:
tf
.
app
.
run
()
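Each record written above is "key, first_sample, second_sample", where the key drawn by np.random.randint(1, 3) is 1 or 2: key 1 means the maximum-likelihood sample comes first, key 2 means the GAN sample comes first, so raters see the two systems in a blind random order that can be unblinded afterwards. A hypothetical unblinding sketch, not part of this commit (the split assumes the samples themselves contain no ", " prefix, which this simple format does not guard against):

def unblind(line):
  # key == '1': (ml, gan) order; key == '2': (gan, ml) order.
  key, first, second = line.split(', ', 2)
  if key == '1':
    return {'ml': first, 'gan': second}
  return {'ml': second, 'gan': first}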
research/maskgan/train_mask_gan.py (new file, mode 100644)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Launch example:
[IMDB]
python train_mask_gan.py --data_dir
/tmp/imdb --data_set imdb --batch_size 128
--sequence_length 20 --base_directory /tmp/maskGAN_v0.01
--hparams="gen_rnn_size=650,gen_num_layers=2,dis_rnn_size=650,dis_num_layers=2
,critic_learning_rate=0.0009756,dis_learning_rate=0.0000585,
dis_train_iterations=8,gen_learning_rate=0.0016624,
gen_full_learning_rate_steps=1e9,gen_learning_rate_decay=0.999999,
rl_discount_rate=0.8835659" --mode TRAIN --max_steps 1000000
--generator_model seq2seq_vd --discriminator_model seq2seq_vd
--is_present_rate 0.5 --summaries_every 25 --print_every 25
--max_num_to_print=3 --generator_optimizer=adam
--seq2seq_share_embedding=True --baseline_method=critic
--attention_option=luong --n_gram_eval=4 --mask_strategy=contiguous
--gen_training_strategy=reinforce --dis_pretrain_steps=100
--perplexity_threshold=1000000
--dis_share_embedding=True --maskgan_ckpt
/tmp/model.ckpt-171091
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
from functools import partial
import os
import time
# Dependency imports
import numpy as np
import tensorflow as tf

import pretrain_mask_gan
from data import imdb_loader
from data import ptb_loader
from model_utils import helper
from model_utils import model_construction
from model_utils import model_losses
from model_utils import model_optimization

# Data.
from model_utils import model_utils
from model_utils import n_gram
from models import evaluation_utils
from models import rollout

np.set_printoptions(precision=3)
np.set_printoptions(suppress=True)

MODE_TRAIN = 'TRAIN'
MODE_TRAIN_EVAL = 'TRAIN_EVAL'
MODE_VALIDATION = 'VALIDATION'
MODE_TEST = 'TEST'
## Binary and setup FLAGS.
tf.app.flags.DEFINE_enum(
    'mode', 'TRAIN', [MODE_TRAIN, MODE_VALIDATION, MODE_TEST, MODE_TRAIN_EVAL],
    'What this binary will do.')
tf.app.flags.DEFINE_string('master', 'local',
                           """Name of the TensorFlow master to use.""")
tf.app.flags.DEFINE_string(
    'eval_master', 'local',
    """Name prefix of the Tensorflow eval master, or "local".""")
tf.app.flags.DEFINE_integer(
    'task', 0, """Task id of the replica running the training.""")
tf.app.flags.DEFINE_integer(
    'ps_tasks', 0, """Number of tasks in the ps job.  If 0 no ps job is used.""")

## General FLAGS.
tf.app.flags.DEFINE_string(
    'hparams', '', 'Comma separated list of name=value hyperparameter pairs.')
tf.app.flags.DEFINE_integer('batch_size', 20, 'The batch size.')
tf.app.flags.DEFINE_integer('vocab_size', 10000, 'The vocabulary size.')
tf.app.flags.DEFINE_integer('sequence_length', 20, 'The sequence length.')
tf.app.flags.DEFINE_integer('max_steps', 1000000,
                            'Maximum number of steps to run.')
tf.app.flags.DEFINE_string(
    'mask_strategy', 'random',
    'Strategy for masking the words.  Determines the characteristics of how '
    "the words are dropped out.  One of ['contiguous', 'random'].")
tf.app.flags.DEFINE_float('is_present_rate', 0.5,
                          'Percent of tokens present in the forward sequence.')
tf.app.flags.DEFINE_float(
    'is_present_rate_decay', None,
    'Decay rate for the percent of words that are real (are present).')
tf.app.flags.DEFINE_string(
    'generator_model', 'seq2seq',
    "Type of Generator model.  One of ['rnn', 'seq2seq', 'seq2seq_zaremba', "
    "'rnn_zaremba', 'rnn_nas', 'seq2seq_nas']")
tf.app.flags.DEFINE_string(
    'attention_option', None,
    "Attention mechanism.  One of [None, 'luong', 'bahdanau']")
tf.app.flags.DEFINE_string(
    'discriminator_model', 'bidirectional',
    "Type of Discriminator model.  One of ['cnn', 'rnn', 'bidirectional', "
    "'rnn_zaremba', 'bidirectional_zaremba', 'rnn_nas', 'rnn_vd', 'seq2seq_vd']")
tf.app.flags.DEFINE_boolean(
    'seq2seq_share_embedding', False,
    'Whether to share the embeddings between the encoder and decoder.')
tf.app.flags.DEFINE_boolean(
    'dis_share_embedding', False,
    'Whether to share the embeddings between the generator and discriminator.')
tf.app.flags.DEFINE_boolean(
    'dis_update_share_embedding', False,
    'Whether the discriminator should update the shared embedding.')
tf.app.flags.DEFINE_boolean(
    'use_gen_mode', False,
    'Use the mode of the generator to produce samples.')
tf.app.flags.DEFINE_boolean(
    'critic_update_dis_vars', False,
    'Whether the critic updates the discriminator variables.')

## Training FLAGS.
tf.app.flags.DEFINE_string(
    'gen_training_strategy', 'reinforce',
    "Method for training the Generator.  One of ['cross_entropy', 'reinforce']")
tf.app.flags.DEFINE_string(
    'generator_optimizer', 'adam',
    "Type of Generator optimizer.  One of ['sgd', 'adam']")
tf.app.flags.DEFINE_float('grad_clipping', 10., 'Norm for gradient clipping.')
tf.app.flags.DEFINE_float('advantage_clipping', 5., 'Clipping for advantages.')
tf.app.flags.DEFINE_string(
    'baseline_method', None,
    "Approach for baseline.  One of ['critic', 'dis_batch', 'ema', None]")
tf.app.flags.DEFINE_float('perplexity_threshold', 15000,
                          'Limit for perplexity before terminating job.')
tf.app.flags.DEFINE_float('zoneout_drop_prob', 0.1,
                          'Probability for dropping parameter for zoneout.')
tf.app.flags.DEFINE_float('keep_prob', 0.5,
                          'Probability for keeping parameter for dropout.')

## Logging and evaluation FLAGS.
tf.app.flags.DEFINE_integer(
    'print_every', 250,
    'Frequency to print and log the outputs of the model.')
tf.app.flags.DEFINE_integer('max_num_to_print', 5,
                            'Number of samples to log/print.')
tf.app.flags.DEFINE_boolean('print_verbose', False, 'Whether to print in full.')
tf.app.flags.DEFINE_integer('summaries_every', 100,
                            'Frequency to compute summaries.')
tf.app.flags.DEFINE_boolean(
    'eval_language_model', False,
    'Whether to evaluate on all words as in language modeling.')
tf.app.flags.DEFINE_float('eval_interval_secs', 60,
                          'Delay for evaluating model.')
tf.app.flags.DEFINE_integer(
    'n_gram_eval', 4, """The degree of the n-grams to use for evaluation.""")
tf.app.flags.DEFINE_integer(
    'epoch_size_override', None,
    'If an integer, this dictates the size of the epochs and will potentially '
    'not iterate over all the data.')
tf.app.flags.DEFINE_integer('eval_epoch_size_override', None,
                            'Number of evaluation steps.')

## Directories and checkpoints.
tf.app.flags.DEFINE_string('base_directory', '/tmp/maskGAN_v0.00',
                           'Base directory for the logging, events and graph.')
tf.app.flags.DEFINE_string(
    'data_set', 'ptb', "Data set to operate on.  One of ['ptb', 'imdb']")
tf.app.flags.DEFINE_string('data_dir', '/tmp/data/ptb',
                           'Directory for the training data.')
tf.app.flags.DEFINE_string(
    'language_model_ckpt_dir', None,
    'Directory storing checkpoints to initialize the model.  Pretrained models '
    'are stored at /tmp/maskGAN/pretrained/')
tf.app.flags.DEFINE_string(
    'language_model_ckpt_dir_reversed', None,
    'Directory storing checkpoints of reversed models to initialize the model. '
    'Pretrained models are stored at /tmp/PTB/pretrained_reversed')
tf.app.flags.DEFINE_string(
    'maskgan_ckpt', None,
    'Override which checkpoint file to use to restore the model.  A pretrained '
    'seq2seq_zaremba model is stored at '
    '/tmp/maskGAN/pretrain/seq2seq_zaremba/train/model.ckpt-64912')

tf.app.flags.DEFINE_boolean('wasserstein_objective', False,
                            '(DEPRECATED) Whether to use the WGAN training.')
tf.app.flags.DEFINE_integer('num_rollouts', 1,
                            'The number of rolled out predictions to make.')
tf.app.flags.DEFINE_float('c_lower', -0.01, 'Lower bound for weights.')
tf.app.flags.DEFINE_float('c_upper', 0.01, 'Upper bound for weights.')

FLAGS = tf.app.flags.FLAGS
def create_hparams():
  """Create the hparams object for generic training hyperparameters."""
  hparams = tf.contrib.training.HParams(
      gen_num_layers=2,
      dis_num_layers=2,
      gen_rnn_size=740,
      dis_rnn_size=740,
      gen_learning_rate=5e-4,
      dis_learning_rate=5e-3,
      critic_learning_rate=5e-3,
      dis_train_iterations=1,
      gen_learning_rate_decay=1.0,
      gen_full_learning_rate_steps=1e7,
      baseline_decay=0.999999,
      rl_discount_rate=0.9,
      gen_vd_keep_prob=0.5,
      dis_vd_keep_prob=0.5,
      dis_pretrain_learning_rate=5e-3,
      dis_num_filters=128,
      dis_hidden_dim=128,
      gen_nas_keep_prob_0=0.85,
      gen_nas_keep_prob_1=0.55,
      dis_nas_keep_prob_0=0.85,
      dis_nas_keep_prob_1=0.55)
  # Command line flags override any of the preceding hyperparameter values.
  if FLAGS.hparams:
    hparams = hparams.parse(FLAGS.hparams)
  return hparams
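The --hparams flag overrides these defaults through HParams.parse, which takes a comma-separated name=value string. An illustrative equivalent in code (override values chosen arbitrarily, not from this commit):

# Programmatic equivalent of --hparams="gen_rnn_size=650,dis_train_iterations=8"
hparams = tf.contrib.training.HParams(gen_rnn_size=740, dis_train_iterations=1)
hparams.parse('gen_rnn_size=650,dis_train_iterations=8')
print(hparams.gen_rnn_size)          # 650 (default was 740)
print(hparams.dis_train_iterations)  # 8 (default was 1)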
def create_MaskGAN(hparams, is_training):
  """Create the MaskGAN model.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    is_training: Boolean indicating operational mode (train/inference).
      evaluated with a teacher forcing regime.

  Return:
    model: Namedtuple for specifying the MaskGAN.
  """
  global_step = tf.Variable(0, name='global_step', trainable=False)

  new_learning_rate = tf.placeholder(tf.float32, [], name='new_learning_rate')
  learning_rate = tf.Variable(0.0, name='learning_rate', trainable=False)
  learning_rate_update = tf.assign(learning_rate, new_learning_rate)

  new_rate = tf.placeholder(tf.float32, [], name='new_rate')
  percent_real_var = tf.Variable(0.0, trainable=False)
  percent_real_update = tf.assign(percent_real_var, new_rate)

  ## Placeholders.
  inputs = tf.placeholder(
      tf.int32, shape=[FLAGS.batch_size, FLAGS.sequence_length])
  targets = tf.placeholder(
      tf.int32, shape=[FLAGS.batch_size, FLAGS.sequence_length])
  present = tf.placeholder(
      tf.bool, shape=[FLAGS.batch_size, FLAGS.sequence_length])
  # TODO(adai): Placeholder for IMDB label.

  ## Real Sequence is the targets.
  real_sequence = targets

  ## Fake Sequence from the Generator.
  # TODO(adai): Generator must have IMDB labels placeholder.
  (fake_sequence, fake_logits, fake_log_probs, fake_gen_initial_state,
   fake_gen_final_state, _) = model_construction.create_generator(
       hparams,
       inputs,
       targets,
       present,
       is_training=is_training,
       is_validating=False)
  (_, eval_logits, _, eval_initial_state, eval_final_state,
   _) = model_construction.create_generator(
       hparams,
       inputs,
       targets,
       present,
       is_training=False,
       is_validating=True,
       reuse=True)

  ## Discriminator.
  fake_predictions = model_construction.create_discriminator(
      hparams,
      fake_sequence,
      is_training=is_training,
      inputs=inputs,
      present=present)
  real_predictions = model_construction.create_discriminator(
      hparams,
      real_sequence,
      is_training=is_training,
      reuse=True,
      inputs=inputs,
      present=present)

  ## Critic.
  # The critic will be used to estimate the forward rewards to the Generator.
  if FLAGS.baseline_method == 'critic':
    est_state_values = model_construction.create_critic(
        hparams, fake_sequence, is_training=is_training)
  else:
    est_state_values = None

  ## Discriminator Loss.
  [dis_loss, dis_loss_fake, dis_loss_real] = model_losses.create_dis_loss(
      fake_predictions, real_predictions, present)

  ## Average log-perplexity for only missing words.  However, to do this,
  # the logits are still computed using teacher forcing, that is, the ground
  # truth tokens are fed in at each time point to be valid.
  avg_log_perplexity = model_losses.calculate_log_perplexity(
      eval_logits, targets, present)

  ## Generator Objective.
  # 1. Cross Entropy losses on missing tokens.
  fake_cross_entropy_losses = model_losses.create_masked_cross_entropy_loss(
      targets, present, fake_logits)

  # 2. GAN REINFORCE losses.
  [
      fake_RL_loss, fake_log_probs, fake_rewards, fake_advantages,
      fake_baselines, fake_averages_op, critic_loss, cumulative_rewards
  ] = model_losses.calculate_reinforce_objective(
      hparams, fake_log_probs, fake_predictions, present, est_state_values)

  ## Pre-training.
  if FLAGS.gen_pretrain_steps:
    raise NotImplementedError
    # # TODO(liamfedus): Rewrite this.
    # fwd_cross_entropy_loss = tf.reduce_mean(fwd_cross_entropy_losses)
    # gen_pretrain_op = model_optimization.create_gen_pretrain_op(
    #     hparams, fwd_cross_entropy_loss, global_step)
  else:
    gen_pretrain_op = None
  if FLAGS.dis_pretrain_steps:
    dis_pretrain_op = model_optimization.create_dis_pretrain_op(
        hparams, dis_loss, global_step)
  else:
    dis_pretrain_op = None

  ## Generator Train Op.
  # 1. Cross-Entropy.
  if FLAGS.gen_training_strategy == 'cross_entropy':
    gen_loss = tf.reduce_mean(fake_cross_entropy_losses)
    [gen_train_op, gen_grads,
     gen_vars] = model_optimization.create_gen_train_op(
         hparams, learning_rate, gen_loss, global_step, mode='MINIMIZE')

  # 2. GAN (REINFORCE)
  elif FLAGS.gen_training_strategy == 'reinforce':
    gen_loss = fake_RL_loss
    [gen_train_op, gen_grads,
     gen_vars] = model_optimization.create_reinforce_gen_train_op(
         hparams, learning_rate, gen_loss, fake_averages_op, global_step)

  else:
    raise NotImplementedError

  ## Discriminator Train Op.
  dis_train_op, dis_grads, dis_vars = model_optimization.create_dis_train_op(
      hparams, dis_loss, global_step)

  ## Critic Train Op.
  if critic_loss is not None:
    [critic_train_op, _, _] = model_optimization.create_critic_train_op(
        hparams, critic_loss, global_step)
    dis_train_op = tf.group(dis_train_op, critic_train_op)

  ## Summaries.
  with tf.name_scope('general'):
    tf.summary.scalar('percent_real', percent_real_var)
    tf.summary.scalar('learning_rate', learning_rate)

  with tf.name_scope('generator_objectives'):
    tf.summary.scalar('gen_objective', tf.reduce_mean(gen_loss))
    tf.summary.scalar('gen_loss_cross_entropy',
                      tf.reduce_mean(fake_cross_entropy_losses))

  with tf.name_scope('REINFORCE'):
    with tf.name_scope('objective'):
      tf.summary.scalar('fake_RL_loss', tf.reduce_mean(fake_RL_loss))
    with tf.name_scope('rewards'):
      helper.variable_summaries(cumulative_rewards, 'rewards')
    with tf.name_scope('advantages'):
      helper.variable_summaries(fake_advantages, 'advantages')
    with tf.name_scope('baselines'):
      helper.variable_summaries(fake_baselines, 'baselines')
    with tf.name_scope('log_probs'):
      helper.variable_summaries(fake_log_probs, 'log_probs')

  with tf.name_scope('discriminator_losses'):
    tf.summary.scalar('dis_loss', dis_loss)
    tf.summary.scalar('dis_loss_fake_sequence', dis_loss_fake)
    tf.summary.scalar('dis_loss_prob_fake_sequence', tf.exp(-dis_loss_fake))
    tf.summary.scalar('dis_loss_real_sequence', dis_loss_real)
    tf.summary.scalar('dis_loss_prob_real_sequence', tf.exp(-dis_loss_real))

  if critic_loss is not None:
    with tf.name_scope('critic_losses'):
      tf.summary.scalar('critic_loss', critic_loss)

  with tf.name_scope('logits'):
    helper.variable_summaries(fake_logits, 'fake_logits')

  for v, g in zip(gen_vars, gen_grads):
    helper.variable_summaries(v, v.op.name)
    helper.variable_summaries(g, 'grad/' + v.op.name)

  for v, g in zip(dis_vars, dis_grads):
    helper.variable_summaries(v, v.op.name)
    helper.variable_summaries(g, 'grad/' + v.op.name)

  merge_summaries_op = tf.summary.merge_all()
  text_summary_placeholder = tf.placeholder(tf.string)
  text_summary_op = tf.summary.text('Samples', text_summary_placeholder)

  # Model saver.
  saver = tf.train.Saver(keep_checkpoint_every_n_hours=1, max_to_keep=5)

  # Named tuple that captures elements of the MaskGAN model.
  Model = collections.namedtuple('Model', [
      'inputs', 'targets', 'present', 'percent_real_update', 'new_rate',
      'fake_sequence', 'fake_logits', 'fake_rewards', 'fake_baselines',
      'fake_advantages', 'fake_log_probs', 'fake_predictions',
      'real_predictions', 'fake_cross_entropy_losses',
      'fake_gen_initial_state', 'fake_gen_final_state', 'eval_initial_state',
      'eval_final_state', 'avg_log_perplexity', 'dis_loss', 'gen_loss',
      'critic_loss', 'cumulative_rewards', 'dis_train_op', 'gen_train_op',
      'gen_pretrain_op', 'dis_pretrain_op', 'merge_summaries_op',
      'global_step', 'new_learning_rate', 'learning_rate_update', 'saver',
      'text_summary_op', 'text_summary_placeholder'
  ])

  model = Model(
      inputs, targets, present, percent_real_update, new_rate, fake_sequence,
      fake_logits, fake_rewards, fake_baselines, fake_advantages,
      fake_log_probs, fake_predictions, real_predictions,
      fake_cross_entropy_losses, fake_gen_initial_state, fake_gen_final_state,
      eval_initial_state, eval_final_state, avg_log_perplexity, dis_loss,
      gen_loss, critic_loss, cumulative_rewards, dis_train_op, gen_train_op,
      gen_pretrain_op, dis_pretrain_op, merge_summaries_op, global_step,
      new_learning_rate, learning_rate_update, saver, text_summary_op,
      text_summary_placeholder)
  return model
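The REINFORCE quantities wired up above (rewards, baselines, clipped advantages) are computed inside model_losses.calculate_reinforce_objective, which is not part of this diff. As rough orientation only, a self-contained toy of the generic REINFORCE-with-baseline objective those tensors correspond to; the actual implementation may differ in details such as discounting:

def toy_reinforce_loss(log_probs, rewards, baselines, clip=5.0):
  """Illustrative only; numpy arrays of shape [batch, time]."""
  # Advantage: how much better the sampled tokens did than the baseline,
  # clipped for stability (cf. the advantage_clipping flag, default 5.).
  advantages = np.clip(rewards - baselines, -clip, clip)
  # REINFORCE ascends advantage-weighted log-probability; negate for a loss.
  # Advantages are treated as constants: no gradient flows through them.
  return -np.mean(log_probs * advantages)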
def compute_geometric_average(percent_captured):
  """Compute the geometric average of the n-gram metrics."""
  res = 1.
  for _, n_gram_percent in percent_captured.iteritems():
    res *= n_gram_percent
  return np.power(res, 1. / float(len(percent_captured)))


def compute_arithmetic_average(percent_captured):
  """Compute the arithmetic average of the n-gram metrics."""
  N = len(percent_captured)
  res = 0.
  for _, n_gram_percent in percent_captured.iteritems():
    res += n_gram_percent
  return res / float(N)
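A quick worked example with made-up scores {2: 0.4, 3: 0.2, 4: 0.1}: the arithmetic average is (0.4 + 0.2 + 0.1) / 3 ≈ 0.233, while the geometric average is (0.4 · 0.2 · 0.1)^(1/3) = 0.008^(1/3) = 0.2. The geometric mean is always at most the arithmetic mean and punishes a single weak n-gram order more heavily.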
def get_iterator(data):
  """Return the data iterator."""
  if FLAGS.data_set == 'ptb':
    iterator = ptb_loader.ptb_iterator(data, FLAGS.batch_size,
                                       FLAGS.sequence_length,
                                       FLAGS.epoch_size_override)
  elif FLAGS.data_set == 'imdb':
    iterator = imdb_loader.imdb_iterator(data, FLAGS.batch_size,
                                         FLAGS.sequence_length)
  return iterator
def train_model(hparams, data, log_dir, log, id_to_word, data_ngram_counts):
  """Train model.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    data: Data to evaluate.
    log_dir: Directory to save checkpoints.
    log: Readable log for the experiment.
    id_to_word: Dictionary of indices to words.
    data_ngram_counts: Dictionary of hashed(n-gram tuples) to counts in the
      data_set.
  """
  print('Training model.')
  tf.logging.info('Training model.')

  # Boolean indicating operational mode.
  is_training = True

  # Write all the information to the logs.
  log.write('hparams\n')
  log.write(str(hparams))
  log.flush()

  is_chief = FLAGS.task == 0

  with tf.Graph().as_default():
    with tf.device(tf.ReplicaDeviceSetter(FLAGS.ps_tasks)):
      container_name = ''
      with tf.container(container_name):
        # Construct the model.
        if FLAGS.num_rollouts == 1:
          model = create_MaskGAN(hparams, is_training)
        elif FLAGS.num_rollouts > 1:
          model = rollout.create_rollout_MaskGAN(hparams, is_training)
        else:
          raise ValueError

        print('\nTrainable Variables in Graph:')
        for v in tf.trainable_variables():
          print(v)

        ## Retrieve the initial savers.
        init_savers = model_utils.retrieve_init_savers(hparams)

        ## Initial saver function to supervisor.
        init_fn = partial(model_utils.init_fn, init_savers)

        # Create the supervisor.  It will take care of initialization,
        # summaries, checkpoints, and recovery.
        sv = tf.Supervisor(
            logdir=log_dir,
            is_chief=is_chief,
            saver=model.saver,
            global_step=model.global_step,
            save_model_secs=60,
            recovery_wait_secs=30,
            summary_op=None,
            init_fn=init_fn)

        # Get an initialized, and possibly recovered session.  Launch the
        # services: Checkpointing, Summaries, step counting.
        #
        # When multiple replicas of this program are running the services are
        # only launched by the 'chief' replica.
        with sv.managed_session(FLAGS.master) as sess:

          ## Pretrain the generator.
          if FLAGS.gen_pretrain_steps:
            pretrain_mask_gan.pretrain_generator(sv, sess, model, data, log,
                                                 id_to_word,
                                                 data_ngram_counts, is_chief)

          ## Pretrain the discriminator.
          if FLAGS.dis_pretrain_steps:
            pretrain_mask_gan.pretrain_discriminator(sv, sess, model, data,
                                                     log, id_to_word,
                                                     data_ngram_counts,
                                                     is_chief)

          # Initial indicators for printing and summarizing.
          print_step_division = -1
          summary_step_division = -1

          # Run iterative computation in a loop.
          while not sv.ShouldStop():
            is_present_rate = FLAGS.is_present_rate
            if FLAGS.is_present_rate_decay is not None:
              is_present_rate *= (1. - FLAGS.is_present_rate_decay)
            model_utils.assign_percent_real(sess, model.percent_real_update,
                                            model.new_rate, is_present_rate)

            # GAN training.
            avg_epoch_gen_loss, avg_epoch_dis_loss = [], []
            cumulative_costs = 0.
            gen_iters = 0

            # Generator and Discriminator statefulness initial evaluation.
            # TODO(liamfedus): Throughout the code I am implicitly assuming
            # that the Generator and Discriminator are equal sized.
            [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run(
                [model.eval_initial_state, model.fake_gen_initial_state])
            dis_initial_state_eval = fake_gen_initial_state_eval

            # Save zeros state to reset later.
            zeros_state = fake_gen_initial_state_eval

            ## Offset Discriminator.
            if FLAGS.ps_tasks == 0:
              dis_offset = 1
            else:
              dis_offset = FLAGS.task * 1000 + 1
            dis_iterator = get_iterator(data)
            for i in range(dis_offset):
              try:
                dis_x, dis_y, _ = next(dis_iterator)
              except StopIteration:
                dis_iterator = get_iterator(data)
                dis_initial_state_eval = zeros_state
                dis_x, dis_y, _ = next(dis_iterator)

              p = model_utils.generate_mask()

              # Construct the train feed.
              train_feed = {
                  model.inputs: dis_x,
                  model.targets: dis_y,
                  model.present: p
              }

              if FLAGS.data_set == 'ptb':
                # Statefulness of the Generator being used for Discriminator.
                for i, (c, h) in enumerate(model.fake_gen_initial_state):
                  train_feed[c] = dis_initial_state_eval[i].c
                  train_feed[h] = dis_initial_state_eval[i].h

                # Determine the state had the Generator run over real data.
                # We use this state for the Discriminator.
                [dis_initial_state_eval] = sess.run(
                    [model.fake_gen_final_state], train_feed)

            ## Training loop.
            iterator = get_iterator(data)
            gen_initial_state_eval = zeros_state

            if FLAGS.ps_tasks > 0:
              gen_offset = FLAGS.task * 1000 + 1
              for i in range(gen_offset):
                try:
                  next(iterator)
                except StopIteration:
                  dis_iterator = get_iterator(data)
                  dis_initial_state_eval = zeros_state
                  next(dis_iterator)

            for x, y, _ in iterator:
              for _ in xrange(hparams.dis_train_iterations):
                try:
                  dis_x, dis_y, _ = next(dis_iterator)
                except StopIteration:
                  dis_iterator = get_iterator(data)
                  dis_initial_state_eval = zeros_state
                  dis_x, dis_y, _ = next(dis_iterator)

                  if FLAGS.data_set == 'ptb':
                    [dis_initial_state_eval] = sess.run(
                        [model.fake_gen_initial_state])

                p = model_utils.generate_mask()

                # Construct the train feed.
                train_feed = {
                    model.inputs: dis_x,
                    model.targets: dis_y,
                    model.present: p
                }

                # Statefulness for the Discriminator.
                if FLAGS.data_set == 'ptb':
                  for i, (c, h) in enumerate(model.fake_gen_initial_state):
                    train_feed[c] = dis_initial_state_eval[i].c
                    train_feed[h] = dis_initial_state_eval[i].h

                _, dis_loss_eval, step = sess.run(
                    [model.dis_train_op, model.dis_loss, model.global_step],
                    feed_dict=train_feed)

                # Determine the state had the Generator run over real data.
                # Use this state for the Discriminator.
                [dis_initial_state_eval] = sess.run(
                    [model.fake_gen_final_state], train_feed)

              # Randomly mask out tokens.
              p = model_utils.generate_mask()

              # Construct the train feed.
              train_feed = {
                  model.inputs: x,
                  model.targets: y,
                  model.present: p
              }

              # Statefulness for Generator.
              if FLAGS.data_set == 'ptb':
                tf.logging.info('Generator is stateful.')
                print('Generator is stateful.')
                # Statefulness for *evaluation* Generator.
                for i, (c, h) in enumerate(model.eval_initial_state):
                  train_feed[c] = gen_initial_state_eval[i].c
                  train_feed[h] = gen_initial_state_eval[i].h

                # Statefulness for Generator.
                for i, (c, h) in enumerate(model.fake_gen_initial_state):
                  train_feed[c] = fake_gen_initial_state_eval[i].c
                  train_feed[h] = fake_gen_initial_state_eval[i].h

              # Determine whether to decay learning rate.
              lr_decay = hparams.gen_learning_rate_decay**max(
                  step + 1 - hparams.gen_full_learning_rate_steps, 0.0)

              # Assign learning rate.
              gen_learning_rate = hparams.gen_learning_rate * lr_decay
              model_utils.assign_learning_rate(sess,
                                               model.learning_rate_update,
                                               model.new_learning_rate,
                                               gen_learning_rate)

              [_, gen_loss_eval, gen_log_perplexity_eval, step] = sess.run(
                  [
                      model.gen_train_op, model.gen_loss,
                      model.avg_log_perplexity, model.global_step
                  ],
                  feed_dict=train_feed)

              cumulative_costs += gen_log_perplexity_eval
              gen_iters += 1

              # Determine the state had the Generator run over real data.
              [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run(
                  [model.eval_final_state, model.fake_gen_final_state],
                  train_feed)

              avg_epoch_dis_loss.append(dis_loss_eval)
              avg_epoch_gen_loss.append(gen_loss_eval)

              ## Summaries.
              # Calculate rolling perplexity.
              perplexity = np.exp(cumulative_costs / gen_iters)

              if is_chief and (step / FLAGS.summaries_every >
                               summary_step_division):
                summary_step_division = step / FLAGS.summaries_every

                # Confirm perplexity is not infinite.
                if (not np.isfinite(perplexity) or
                    perplexity >= FLAGS.perplexity_threshold):
                  print('Training raising FloatingPointError.')
                  raise FloatingPointError(
                      'Training infinite perplexity: %.3f' % perplexity)

                # Graph summaries.
                summary_str = sess.run(
                    model.merge_summaries_op, feed_dict=train_feed)
                sv.SummaryComputed(sess, summary_str)

                # Summary: n-gram
                avg_percent_captured = {'2': 0., '3': 0., '4': 0.}
                for n, data_ngram_count in data_ngram_counts.iteritems():
                  batch_percent_captured = (
                      evaluation_utils.sequence_ngram_evaluation(
                          sess, model.fake_sequence, log, train_feed,
                          data_ngram_count, int(n)))
                  summary_percent_str = tf.Summary(value=[
                      tf.Summary.Value(
                          tag='general/%s-grams_percent_correct' % n,
                          simple_value=batch_percent_captured)
                  ])
                  sv.SummaryComputed(
                      sess, summary_percent_str, global_step=step)

                # Summary: geometric_avg
                geometric_avg = compute_geometric_average(
                    avg_percent_captured)
                summary_geometric_avg_str = tf.Summary(value=[
                    tf.Summary.Value(
                        tag='general/geometric_avg',
                        simple_value=geometric_avg)
                ])
                sv.SummaryComputed(
                    sess, summary_geometric_avg_str, global_step=step)

                # Summary: arithmetic_avg
                arithmetic_avg = compute_arithmetic_average(
                    avg_percent_captured)
                summary_arithmetic_avg_str = tf.Summary(value=[
                    tf.Summary.Value(
                        tag='general/arithmetic_avg',
                        simple_value=arithmetic_avg)
                ])
                sv.SummaryComputed(
                    sess, summary_arithmetic_avg_str, global_step=step)

                # Summary: perplexity
                summary_perplexity_str = tf.Summary(value=[
                    tf.Summary.Value(
                        tag='general/perplexity', simple_value=perplexity)
                ])
                sv.SummaryComputed(
                    sess, summary_perplexity_str, global_step=step)

              ## Printing and logging
              if is_chief and (step / FLAGS.print_every >
                               print_step_division):
                print_step_division = (step / FLAGS.print_every)
                print('global_step: %d' % step)
                print(' perplexity: %.3f' % perplexity)
                print(' gen_learning_rate: %.6f' % gen_learning_rate)
                log.write('global_step: %d\n' % step)
                log.write(' perplexity: %.3f\n' % perplexity)
                log.write(' gen_learning_rate: %.6f' % gen_learning_rate)

                # Average percent captured for each of the n-grams.
                avg_percent_captured = {'2': 0., '3': 0., '4': 0.}
                for n, data_ngram_count in data_ngram_counts.iteritems():
                  batch_percent_captured = (
                      evaluation_utils.sequence_ngram_evaluation(
                          sess, model.fake_sequence, log, train_feed,
                          data_ngram_count, int(n)))
                  avg_percent_captured[n] = batch_percent_captured
                  print(' percent of %s-grams captured: %.3f.' %
                        (n, batch_percent_captured))
                  log.write(' percent of %s-grams captured: %.3f.\n' %
                            (n, batch_percent_captured))
                geometric_avg = compute_geometric_average(
                    avg_percent_captured)
                print(' geometric_avg: %.3f.' % geometric_avg)
                log.write(' geometric_avg: %.3f.' % geometric_avg)
                arithmetic_avg = compute_arithmetic_average(
                    avg_percent_captured)
                print(' arithmetic_avg: %.3f.' % arithmetic_avg)
                log.write(' arithmetic_avg: %.3f.' % arithmetic_avg)

                evaluation_utils.print_and_log_losses(log, step,
                                                      is_present_rate,
                                                      avg_epoch_dis_loss,
                                                      avg_epoch_gen_loss)

                if FLAGS.gen_training_strategy == 'reinforce':
                  evaluation_utils.generate_RL_logs(sess, model, log,
                                                    id_to_word, train_feed)
                else:
                  evaluation_utils.generate_logs(sess, model, log,
                                                 id_to_word, train_feed)
                log.flush()

          log.close()
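The generator learning-rate schedule inside the loop above is a plain exponential decay that only begins after gen_full_learning_rate_steps. A standalone sketch of the same arithmetic, using the create_hparams defaults (illustrative, not part of the commit):

def gen_lr(step, base_lr=5e-4, decay=0.999999, full_lr_steps=1e7):
  # Full learning rate until full_lr_steps, multiplicative decay afterwards.
  return base_lr * decay**max(step + 1 - full_lr_steps, 0.0)

print(gen_lr(0))    # 5e-4: before the threshold the rate is untouched
print(gen_lr(2e7))  # ~2.3e-8: 5e-4 * 0.999999**(1e7 + 1)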
def evaluate_once(data, sv, model, sess, train_dir, log, id_to_word,
                  data_ngram_counts, eval_saver):
  """Evaluate model for a number of steps.

  Args:
    data: Dataset.
    sv: Supervisor.
    model: The GAN model we have just built.
    sess: A session to use.
    train_dir: Path to a directory containing checkpoints.
    log: Evaluation log for evaluation.
    id_to_word: Dictionary of indices to words.
    data_ngram_counts: Dictionary of hashed(n-gram tuples) to counts in the
      data_set.
    eval_saver: Evaluation saver.
  """
  tf.logging.info('Evaluate Once.')

  # Load the last model checkpoint, or initialize the graph.
  model_save_path = tf.latest_checkpoint(train_dir)
  if not model_save_path:
    tf.logging.warning('No checkpoint yet in: %s', train_dir)
    return

  tf.logging.info('Starting eval of: %s' % model_save_path)
  tf.logging.info('Only restoring trainable variables.')
  eval_saver.restore(sess, model_save_path)

  # Run the requested number of evaluation steps
  avg_epoch_gen_loss, avg_epoch_dis_loss = [], []
  cumulative_costs = 0.

  # Average percent captured for each of the n-grams.
  avg_percent_captured = {'2': 0., '3': 0., '4': 0.}

  # Set a random seed to keep fixed mask.
  np.random.seed(0)
  gen_iters = 0

  # Generator statefulness over the epoch.
  # TODO(liamfedus): Check this.
  [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run(
      [model.eval_initial_state, model.fake_gen_initial_state])

  if FLAGS.eval_language_model:
    is_present_rate = 0.
    tf.logging.info('Overriding is_present_rate=0. for evaluation.')
    print('Overriding is_present_rate=0. for evaluation.')

  iterator = get_iterator(data)

  for x, y, _ in iterator:
    if FLAGS.eval_language_model:
      is_present_rate = 0.
    else:
      is_present_rate = FLAGS.is_present_rate
    tf.logging.info('Evaluating on is_present_rate=%.3f.' % is_present_rate)

    model_utils.assign_percent_real(sess, model.percent_real_update,
                                    model.new_rate, is_present_rate)

    # Randomly mask out tokens.
    p = model_utils.generate_mask()
    eval_feed = {model.inputs: x, model.targets: y, model.present: p}

    if FLAGS.data_set == 'ptb':
      # Statefulness for *evaluation* Generator.
      for i, (c, h) in enumerate(model.eval_initial_state):
        eval_feed[c] = gen_initial_state_eval[i].c
        eval_feed[h] = gen_initial_state_eval[i].h

      # Statefulness for the Generator.
      for i, (c, h) in enumerate(model.fake_gen_initial_state):
        eval_feed[c] = fake_gen_initial_state_eval[i].c
        eval_feed[h] = fake_gen_initial_state_eval[i].h

    [
        gen_log_perplexity_eval, dis_loss_eval, gen_loss_eval,
        gen_initial_state_eval, fake_gen_initial_state_eval, step
    ] = sess.run(
        [
            model.avg_log_perplexity, model.dis_loss, model.gen_loss,
            model.eval_final_state, model.fake_gen_final_state,
            model.global_step
        ],
        feed_dict=eval_feed)

    for n, data_ngram_count in data_ngram_counts.iteritems():
      batch_percent_captured = evaluation_utils.sequence_ngram_evaluation(
          sess, model.fake_sequence, log, eval_feed, data_ngram_count, int(n))
      avg_percent_captured[n] += batch_percent_captured

    cumulative_costs += gen_log_perplexity_eval

    avg_epoch_dis_loss.append(dis_loss_eval)
    avg_epoch_gen_loss.append(gen_loss_eval)

    gen_iters += 1

  # Calculate rolling metrics.
  perplexity = np.exp(cumulative_costs / gen_iters)
  for n, _ in avg_percent_captured.iteritems():
    avg_percent_captured[n] /= gen_iters

  # Confirm perplexity is not infinite.
  if not np.isfinite(perplexity) or perplexity >= FLAGS.perplexity_threshold:
    print('Evaluation raising FloatingPointError.')
    raise FloatingPointError(
        'Evaluation infinite perplexity: %.3f' % perplexity)

  ## Printing and logging.
  evaluation_utils.print_and_log_losses(log, step, is_present_rate,
                                        avg_epoch_dis_loss,
                                        avg_epoch_gen_loss)
  print(' perplexity: %.3f' % perplexity)
  log.write(' perplexity: %.3f\n' % perplexity)

  for n, n_gram_percent in avg_percent_captured.iteritems():
    n = int(n)
    print(' percent of %d-grams captured: %.3f.' % (n, n_gram_percent))
    log.write(' percent of %d-grams captured: %.3f.\n' % (n, n_gram_percent))

  samples = evaluation_utils.generate_logs(sess, model, log, id_to_word,
                                           eval_feed)

  ## Summaries.
  summary_str = sess.run(model.merge_summaries_op, feed_dict=eval_feed)
  sv.SummaryComputed(sess, summary_str)

  # Summary: text
  summary_str = sess.run(
      model.text_summary_op,
      {model.text_summary_placeholder: '\n\n'.join(samples)})
  sv.SummaryComputed(sess, summary_str, global_step=step)

  # Summary: n-gram
  for n, n_gram_percent in avg_percent_captured.iteritems():
    n = int(n)
    summary_percent_str = tf.Summary(value=[
        tf.Summary.Value(
            tag='general/%d-grams_percent_correct' % n,
            simple_value=n_gram_percent)
    ])
    sv.SummaryComputed(sess, summary_percent_str, global_step=step)

  # Summary: geometric_avg
  geometric_avg = compute_geometric_average(avg_percent_captured)
  summary_geometric_avg_str = tf.Summary(value=[
      tf.Summary.Value(
          tag='general/geometric_avg', simple_value=geometric_avg)
  ])
  sv.SummaryComputed(sess, summary_geometric_avg_str, global_step=step)

  # Summary: arithmetic_avg
  arithmetic_avg = compute_arithmetic_average(avg_percent_captured)
  summary_arithmetic_avg_str = tf.Summary(value=[
      tf.Summary.Value(
          tag='general/arithmetic_avg', simple_value=arithmetic_avg)
  ])
  sv.SummaryComputed(sess, summary_arithmetic_avg_str, global_step=step)

  # Summary: perplexity
  summary_perplexity_str = tf.Summary(value=[
      tf.Summary.Value(tag='general/perplexity', simple_value=perplexity)
  ])
  sv.SummaryComputed(sess, summary_perplexity_str, global_step=step)
def evaluate_model(hparams, data, train_dir, log, id_to_word,
                   data_ngram_counts):
  """Evaluate MaskGAN model.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    data: Data to evaluate.
    train_dir: Path to a directory containing checkpoints.
    log: Evaluation log for evaluation.
    id_to_word: Dictionary of indices to words.
    data_ngram_counts: Dictionary of hashed(n-gram tuples) to counts in the
      data_set.
  """
  tf.logging.error('Evaluate model.')

  # Boolean indicating operational mode.
  is_training = False

  if FLAGS.mode == MODE_VALIDATION:
    logdir = FLAGS.base_directory + '/validation'
  elif FLAGS.mode == MODE_TRAIN_EVAL:
    logdir = FLAGS.base_directory + '/train_eval'
  elif FLAGS.mode == MODE_TEST:
    logdir = FLAGS.base_directory + '/test'
  else:
    raise NotImplementedError

  # Wait for a checkpoint to exist.
  print(train_dir)
  print(tf.train.latest_checkpoint(train_dir))
  while not tf.train.latest_checkpoint(train_dir):
    tf.logging.error('Waiting for checkpoint...')
    print('Waiting for checkpoint...')
    time.sleep(10)

  with tf.Graph().as_default():
    # Use a separate container for each trial
    container_name = ''
    with tf.container(container_name):

      # Construct the model.
      if FLAGS.num_rollouts == 1:
        model = create_MaskGAN(hparams, is_training)
      elif FLAGS.num_rollouts > 1:
        model = rollout.create_rollout_MaskGAN(hparams, is_training)
      else:
        raise ValueError

      # Create the supervisor.  It will take care of initialization, summaries,
      # checkpoints, and recovery.  We only pass the trainable variables
      # to load since things like baselines keep batch_size which may not
      # match between training and evaluation.
      evaluation_variables = tf.trainable_variables()
      evaluation_variables.append(model.global_step)
      eval_saver = tf.train.Saver(var_list=evaluation_variables)
      sv = tf.Supervisor(logdir=logdir)
      sess = sv.PrepareSession(FLAGS.eval_master, start_standard_services=False)

      tf.logging.info('Before sv.Loop.')
      sv.Loop(FLAGS.eval_interval_secs, evaluate_once,
              (data, sv, model, sess, train_dir, log, id_to_word,
               data_ngram_counts, eval_saver))

      sv.WaitForStop()
      tf.logging.info('sv.Stop().')
      sv.Stop()
def main(_):
  hparams = create_hparams()
  train_dir = FLAGS.base_directory + '/train'

  # Load data set.
  if FLAGS.data_set == 'ptb':
    raw_data = ptb_loader.ptb_raw_data(FLAGS.data_dir)
    train_data, valid_data, test_data, _ = raw_data
    valid_data_flat = valid_data
  elif FLAGS.data_set == 'imdb':
    raw_data = imdb_loader.imdb_raw_data(FLAGS.data_dir)
    # TODO(liamfedus): Get an IMDB test partition.
    train_data, valid_data = raw_data
    valid_data_flat = [word for review in valid_data for word in review]
  else:
    raise NotImplementedError

  if FLAGS.mode == MODE_TRAIN or FLAGS.mode == MODE_TRAIN_EVAL:
    data_set = train_data
  elif FLAGS.mode == MODE_VALIDATION:
    data_set = valid_data
  elif FLAGS.mode == MODE_TEST:
    data_set = test_data
  else:
    raise NotImplementedError

  # Dictionary and reverse dictionary.
  if FLAGS.data_set == 'ptb':
    word_to_id = ptb_loader.build_vocab(
        os.path.join(FLAGS.data_dir, 'ptb.train.txt'))
  elif FLAGS.data_set == 'imdb':
    word_to_id = imdb_loader.build_vocab(
        os.path.join(FLAGS.data_dir, 'vocab.txt'))
  id_to_word = {v: k for k, v in word_to_id.iteritems()}

  # Dictionary of Training Set n-gram counts.
  bigram_tuples = n_gram.find_all_ngrams(valid_data_flat, n=2)
  trigram_tuples = n_gram.find_all_ngrams(valid_data_flat, n=3)
  fourgram_tuples = n_gram.find_all_ngrams(valid_data_flat, n=4)

  bigram_counts = n_gram.construct_ngrams_dict(bigram_tuples)
  trigram_counts = n_gram.construct_ngrams_dict(trigram_tuples)
  fourgram_counts = n_gram.construct_ngrams_dict(fourgram_tuples)
  print('Unique %d-grams: %d' % (2, len(bigram_counts)))
  print('Unique %d-grams: %d' % (3, len(trigram_counts)))
  print('Unique %d-grams: %d' % (4, len(fourgram_counts)))

  data_ngram_counts = {
      '2': bigram_counts,
      '3': trigram_counts,
      '4': fourgram_counts
  }

  # TODO(liamfedus): This was necessary because there was a problem with our
  # originally trained IMDB models.  The EOS_INDEX was off by one, which means,
  # two words were mapping to index 86933.  The presence of '</s>' is going
  # to throw an out of vocabulary error.
  FLAGS.vocab_size = len(id_to_word)
  print('Vocab size: %d' % FLAGS.vocab_size)

  tf.gfile.MakeDirs(FLAGS.base_directory)

  if FLAGS.mode == MODE_TRAIN:
    log = tf.gfile.GFile(
        os.path.join(FLAGS.base_directory, 'train-log.txt'), mode='w')
  elif FLAGS.mode == MODE_VALIDATION:
    log = tf.gfile.GFile(
        os.path.join(FLAGS.base_directory, 'validation-log.txt'), mode='w')
  elif FLAGS.mode == MODE_TRAIN_EVAL:
    log = tf.gfile.GFile(
        os.path.join(FLAGS.base_directory, 'train_eval-log.txt'), mode='w')
  else:
    log = tf.gfile.GFile(
        os.path.join(FLAGS.base_directory, 'test-log.txt'), mode='w')

  if FLAGS.mode == MODE_TRAIN:
    train_model(hparams, data_set, train_dir, log, id_to_word,
                data_ngram_counts)
  elif FLAGS.mode == MODE_VALIDATION:
    evaluate_model(hparams, data_set, train_dir, log, id_to_word,
                   data_ngram_counts)
  elif FLAGS.mode == MODE_TRAIN_EVAL:
    evaluate_model(hparams, data_set, train_dir, log, id_to_word,
                   data_ngram_counts)
  elif FLAGS.mode == MODE_TEST:
    evaluate_model(hparams, data_set, train_dir, log, id_to_word,
                   data_ngram_counts)
  else:
    raise NotImplementedError


if __name__ == '__main__':
  tf.app.run()