ModelZoo / ResNet50_tensorflow · Commits · 7d16fc45

Unverified commit 7d16fc45, authored Feb 27, 2018 by Andrew M Dai, committed via GitHub Feb 27, 2018.
Merge pull request #3486 from a-dai/master: Added new MaskGAN model.
Parents: c5c6eaf2, 87fae3f7
Changes: 45
Showing 20 changed files with 3410 additions and 0 deletions (+3410, -0).
CODEOWNERS  (+1, -0)
research/README.md  (+1, -0)
research/maskgan/README.md  (+90, -0)
research/maskgan/data/__init__.py  (+0, -0)
research/maskgan/data/imdb_loader.py  (+136, -0)
research/maskgan/data/ptb_loader.py  (+123, -0)
research/maskgan/generate_samples.py  (+281, -0)
research/maskgan/losses/__init__.py  (+0, -0)
research/maskgan/losses/losses.py  (+186, -0)
research/maskgan/model_utils/__init__.py  (+0, -0)
research/maskgan/model_utils/helper.py  (+157, -0)
research/maskgan/model_utils/model_construction.py  (+234, -0)
research/maskgan/model_utils/model_losses.py  (+327, -0)
research/maskgan/model_utils/model_optimization.py  (+194, -0)
research/maskgan/model_utils/model_utils.py  (+291, -0)
research/maskgan/model_utils/n_gram.py  (+64, -0)
research/maskgan/model_utils/variable_mapping.py  (+773, -0)
research/maskgan/models/__init__.py  (+0, -0)
research/maskgan/models/attention_utils.py  (+477, -0)
research/maskgan/models/bidirectional.py  (+75, -0)
CODEOWNERS  @@ -18,6 +18,7 @@

  /research/learning_to_remember_rare_events/ @lukaszkaiser @ofirnachum
  /research/lfads/ @jazcollins @susillo
  /research/lm_1b/ @oriolvinyals @panyx0718
+ /research/maskgan/ @a-dai
  /research/namignizer/ @knathanieltucker
  /research/neural_gpu/ @lukaszkaiser
  /research/neural_programmer/ @arvind2505
research/README.md  @@ -38,6 +38,7 @@ installation](https://www.tensorflow.org/install).

  -   [lfads](lfads): sequential variational autoencoder for analyzing
      neuroscience data.
  -   [lm_1b](lm_1b): language modeling on the one billion word benchmark.
+ -   [maskgan](maskgan): text generation with GANs.
  -   [namignizer](namignizer): recognize and generate names.
  -   [neural_gpu](neural_gpu): highly parallel neural computer.
  -   [neural_programmer](neural_programmer): neural network augmented with logic
research/maskgan/README.md (new file, mode 100644)

# MaskGAN: Better Text Generation via Filling in the ______

Code for [*MaskGAN: Better Text Generation via Filling in the ______*](https://arxiv.org/abs/1801.07736) published at ICLR 2018.

## Requirements

*   TensorFlow >= v1.3

## Instructions

Warning: The open-source version of this code is still in the process of being
tested. Pretraining may not work correctly.

For training on PTB:

1.  (Optional) Pretrain a LM on PTB and store the checkpoint in /tmp/pretrain-lm/.
    Instructions WIP.

2.  (Optional) Run MaskGAN in MLE pretraining mode:

```bash
python train_mask_gan.py \
 --data_dir='/tmp/ptb' \
 --batch_size=20 \
 --sequence_length=20 \
 --base_directory='/tmp/maskGAN' \
 --hparams="gen_rnn_size=650,dis_rnn_size=650,gen_num_layers=2,dis_num_layers=2,gen_learning_rate=0.00074876,dis_learning_rate=5e-4,baseline_decay=0.99,dis_train_iterations=1,gen_learning_rate_decay=0.95" \
 --mode='TRAIN' \
 --max_steps=100000 \
 --language_model_ckpt_dir=/tmp/pretrain-lm/ \
 --generator_model='seq2seq_vd' \
 --discriminator_model='rnn_zaremba' \
 --is_present_rate=0.5 \
 --summaries_every=10 \
 --print_every=250 \
 --max_num_to_print=3 \
 --gen_training_strategy=cross_entropy \
 --seq2seq_share_embedding
```

3.  Run MaskGAN in GAN mode:

```bash
python train_mask_gan.py \
 --data_dir='/tmp/ptb' \
 --batch_size=128 \
 --sequence_length=20 \
 --base_directory='/tmp/maskGAN' \
 --mask_strategy=contiguous \
 --maskgan_ckpt='/tmp/maskGAN' \
 --hparams="gen_rnn_size=650,dis_rnn_size=650,gen_num_layers=2,dis_num_layers=2,gen_learning_rate=0.000038877,gen_learning_rate_decay=1.0,gen_full_learning_rate_steps=2000000,gen_vd_keep_prob=0.33971,rl_discount_rate=0.89072,dis_learning_rate=5e-4,baseline_decay=0.99,dis_train_iterations=2,dis_pretrain_learning_rate=0.005,critic_learning_rate=5.1761e-7,dis_vd_keep_prob=0.71940" \
 --mode='TRAIN' \
 --max_steps=100000 \
 --generator_model='seq2seq_vd' \
 --discriminator_model='seq2seq_vd' \
 --is_present_rate=0.5 \
 --summaries_every=250 \
 --print_every=250 \
 --max_num_to_print=3 \
 --gen_training_strategy='reinforce' \
 --seq2seq_share_embedding=true \
 --baseline_method=critic \
 --attention_option=luong
```

4.  Generate samples:

```bash
python generate_samples.py \
 --data_dir /tmp/ptb/ \
 --data_set=ptb \
 --batch_size=256 \
 --sequence_length=20 \
 --base_directory /tmp/imdbsample/ \
 --hparams="gen_rnn_size=650,dis_rnn_size=650,gen_num_layers=2,gen_vd_keep_prob=0.33971" \
 --generator_model=seq2seq_vd \
 --discriminator_model=seq2seq_vd \
 --is_present_rate=0.0 \
 --maskgan_ckpt=/tmp/maskGAN \
 --seq2seq_share_embedding=True \
 --dis_share_embedding=True \
 --attention_option=luong \
 --mask_strategy=contiguous \
 --baseline_method=critic \
 --number_epochs=4
```

## Contact for Issues

*   Liam Fedus, @liamb315 <liam.fedus@gmail.com>
*   Andrew M. Dai, @a-dai <adai@google.com>
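Note on the `--hparams` flag used in the commands above: it packs all model hyperparameters into one comma-separated `key=value` string. The parsing logic lives in `train_mask_gan.create_hparams()`, which is not part of this diff; below is only a minimal sketch of how such a string is typically handled in TF 1.x code of this era, using the `tf.contrib.training.HParams` API with hypothetical default values.

```python
import tensorflow as tf  # TF 1.x

# Hypothetical defaults for illustration; the real defaults live in
# train_mask_gan.create_hparams().
hparams = tf.contrib.training.HParams(
    gen_rnn_size=650,
    dis_rnn_size=650,
    gen_learning_rate=5e-4,
    baseline_decay=0.99)

# Override the defaults from a --hparams style string.
hparams.parse('gen_rnn_size=650,dis_rnn_size=650,gen_learning_rate=0.00074876')
print(hparams.gen_rnn_size, hparams.gen_learning_rate)
```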
research/maskgan/data/__init__.py (new file, empty)
research/maskgan/data/imdb_loader.py (new file, mode 100644)

# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""IMDB data loader and helpers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
# Dependency imports
import numpy as np
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_boolean('prefix_label', True, 'Vocabulary file.')

np.set_printoptions(precision=3)
np.set_printoptions(suppress=True)

EOS_INDEX = 88892


def _read_words(filename, use_prefix=True):
  all_words = []
  sequence_example = tf.train.SequenceExample()
  for r in tf.python_io.tf_record_iterator(filename):
    sequence_example.ParseFromString(r)

    if FLAGS.prefix_label and use_prefix:
      label = sequence_example.context.feature['class'].int64_list.value[0]
      review_words = [EOS_INDEX + 1 + label]
    else:
      review_words = []
    review_words.extend([
        f.int64_list.value[0]
        for f in sequence_example.feature_lists.feature_list['token_id'].feature
    ])
    all_words.append(review_words)
  return all_words


def build_vocab(vocab_file):
  word_to_id = {}

  with tf.gfile.GFile(vocab_file, 'r') as f:
    index = 0
    for word in f:
      word_to_id[word.strip()] = index
      index += 1
    word_to_id['<eos>'] = EOS_INDEX

  return word_to_id


def imdb_raw_data(data_path=None):
  """Load IMDB raw data from data directory "data_path".

  Reads IMDB tf record files containing integer ids,
  and performs mini-batching of the inputs.

  Args:
    data_path: string path to the directory where simple-examples.tgz has
      been extracted.

  Returns:
    tuple (train_data, valid_data)
    where each of the data objects can be passed to IMDBIterator.
  """
  train_path = os.path.join(data_path, 'train_lm.tfrecords')
  valid_path = os.path.join(data_path, 'test_lm.tfrecords')

  train_data = _read_words(train_path)
  valid_data = _read_words(valid_path)
  return train_data, valid_data


def imdb_iterator(raw_data, batch_size, num_steps, epoch_size_override=None):
  """Iterate on the raw IMDB data.

  This generates batch_size pointers into the raw IMDB data, and allows
  minibatch iteration along these pointers.

  Args:
    raw_data: one of the raw data outputs from imdb_raw_data.
    batch_size: int, the batch size.
    num_steps: int, the number of unrolls.

  Yields:
    Pairs of the batched data, each a matrix of shape [batch_size, num_steps].
    The second element of the tuple is the same data time-shifted to the
    right by one. The third is a set of weights with 1 indicating a word was
    present and 0 not.

  Raises:
    ValueError: if batch_size or num_steps are too high.
  """
  del epoch_size_override
  data_len = len(raw_data)
  num_batches = data_len // batch_size - 1

  for batch in range(num_batches):
    x = np.zeros([batch_size, num_steps], dtype=np.int32)
    y = np.zeros([batch_size, num_steps], dtype=np.int32)
    w = np.zeros([batch_size, num_steps], dtype=np.float)

    for i in range(batch_size):
      data_index = batch * batch_size + i
      example = raw_data[data_index]

      if len(example) > num_steps:
        final_x = example[:num_steps]
        final_y = example[1:(num_steps + 1)]
        w[i] = 1
      else:
        to_fill_in = num_steps - len(example)
        final_x = example + [EOS_INDEX] * to_fill_in
        final_y = final_x[1:] + [EOS_INDEX]
        w[i] = [1] * len(example) + [0] * to_fill_in
      x[i] = final_x
      y[i] = final_y

    yield (x, y, w)
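A minimal usage sketch of the loader above, assuming the `train_lm.tfrecords`/`test_lm.tfrecords` files and a `vocab.txt` described in `imdb_raw_data` and `build_vocab` already exist under a hypothetical `/tmp/imdb` directory:

```python
from data import imdb_loader

word_to_id = imdb_loader.build_vocab('/tmp/imdb/vocab.txt')
train_data, valid_data = imdb_loader.imdb_raw_data('/tmp/imdb')

# Each yielded triple is (inputs, shifted targets, presence weights),
# all of shape [batch_size, num_steps].
for x, y, w in imdb_loader.imdb_iterator(train_data, batch_size=20, num_steps=20):
  print(x.shape, y.shape, w.shape)
  break
```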
research/maskgan/data/ptb_loader.py (new file, mode 100644)

# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""PTB data loader and helpers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os
# Dependency imports
import numpy as np
import tensorflow as tf

EOS_INDEX = 0


def _read_words(filename):
  with tf.gfile.GFile(filename, "r") as f:
    return f.read().decode("utf-8").replace("\n", "<eos>").split()


def build_vocab(filename):
  data = _read_words(filename)

  counter = collections.Counter(data)
  count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))

  words, _ = list(zip(*count_pairs))
  word_to_id = dict(zip(words, range(len(words))))
  print("<eos>:", word_to_id["<eos>"])
  global EOS_INDEX
  EOS_INDEX = word_to_id["<eos>"]

  return word_to_id


def _file_to_word_ids(filename, word_to_id):
  data = _read_words(filename)
  return [word_to_id[word] for word in data if word in word_to_id]


def ptb_raw_data(data_path=None):
  """Load PTB raw data from data directory "data_path".

  Reads PTB text files, converts strings to integer ids,
  and performs mini-batching of the inputs.

  The PTB dataset comes from Tomas Mikolov's webpage:
  http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz

  Args:
    data_path: string path to the directory where simple-examples.tgz has
      been extracted.

  Returns:
    tuple (train_data, valid_data, test_data, vocabulary)
    where each of the data objects can be passed to PTBIterator.
  """
  train_path = os.path.join(data_path, "ptb.train.txt")
  valid_path = os.path.join(data_path, "ptb.valid.txt")
  test_path = os.path.join(data_path, "ptb.test.txt")

  word_to_id = build_vocab(train_path)
  train_data = _file_to_word_ids(train_path, word_to_id)
  valid_data = _file_to_word_ids(valid_path, word_to_id)
  test_data = _file_to_word_ids(test_path, word_to_id)
  vocabulary = len(word_to_id)
  return train_data, valid_data, test_data, vocabulary


def ptb_iterator(raw_data, batch_size, num_steps, epoch_size_override=None):
  """Iterate on the raw PTB data.

  This generates batch_size pointers into the raw PTB data, and allows
  minibatch iteration along these pointers.

  Args:
    raw_data: one of the raw data outputs from ptb_raw_data.
    batch_size: int, the batch size.
    num_steps: int, the number of unrolls.

  Yields:
    Pairs of the batched data, each a matrix of shape [batch_size, num_steps].
    The second element of the tuple is the same data time-shifted to the
    right by one.

  Raises:
    ValueError: if batch_size or num_steps are too high.
  """
  raw_data = np.array(raw_data, dtype=np.int32)

  data_len = len(raw_data)
  batch_len = data_len // batch_size
  data = np.full([batch_size, batch_len], EOS_INDEX, dtype=np.int32)
  for i in range(batch_size):
    data[i] = raw_data[batch_len * i:batch_len * (i + 1)]

  if epoch_size_override:
    epoch_size = epoch_size_override
  else:
    epoch_size = (batch_len - 1) // num_steps

  if epoch_size == 0:
    raise ValueError("epoch_size == 0, decrease batch_size or num_steps")

  # print("Number of batches per epoch: %d" % epoch_size)
  for i in range(epoch_size):
    x = data[:, i * num_steps:(i + 1) * num_steps]
    y = data[:, i * num_steps + 1:(i + 1) * num_steps + 1]
    w = np.ones_like(x)

    yield (x, y, w)
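A minimal usage sketch of the PTB loader above, assuming the `ptb.train.txt`, `ptb.valid.txt`, and `ptb.test.txt` files from the simple-examples archive have been extracted to a hypothetical `/tmp/ptb` directory:

```python
from data import ptb_loader

train_data, valid_data, test_data, vocab_size = ptb_loader.ptb_raw_data('/tmp/ptb')
print('vocabulary size:', vocab_size)

# One epoch of (inputs, targets, weights) minibatches; targets are the inputs
# shifted right by one token, and the weights are all ones for PTB.
for x, y, w in ptb_loader.ptb_iterator(train_data, batch_size=20, num_steps=20):
  print(x.shape, y.shape, w.shape)
  break
```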
research/maskgan/generate_samples.py (new file, mode 100644)

# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Generate samples from the MaskGAN.

Launch command:
  python generate_samples.py
  --data_dir=/tmp/data/imdb --data_set=imdb
  --batch_size=256 --sequence_length=20 --base_directory=/tmp/imdb
  --hparams="gen_rnn_size=650,dis_rnn_size=650,gen_num_layers=2,
  gen_vd_keep_prob=1.0" --generator_model=seq2seq_vd
  --discriminator_model=seq2seq_vd --is_present_rate=0.5
  --maskgan_ckpt=/tmp/model.ckpt-45494
  --seq2seq_share_embedding=True --dis_share_embedding=True
  --attention_option=luong --mask_strategy=contiguous --baseline_method=critic
  --number_epochs=4
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from functools import partial
import os
# Dependency imports
import numpy as np
import tensorflow as tf

import train_mask_gan
from data import imdb_loader
from data import ptb_loader

# Data.
from model_utils import helper
from model_utils import model_utils

SAMPLE_TRAIN = 'TRAIN'
SAMPLE_VALIDATION = 'VALIDATION'

## Sample Generation.
## Binary and setup FLAGS.
tf.app.flags.DEFINE_enum('sample_mode', 'TRAIN',
                         [SAMPLE_TRAIN, SAMPLE_VALIDATION],
                         'Dataset to sample from.')
tf.app.flags.DEFINE_string('output_path', '/tmp', 'Model output directory.')
tf.app.flags.DEFINE_boolean(
    'output_masked_logs', False,
    'Whether to display for human evaluation (show masking).')
tf.app.flags.DEFINE_integer('number_epochs', 1,
                            'The number of epochs to produce.')

FLAGS = tf.app.flags.FLAGS


def get_iterator(data):
  """Return the data iterator."""
  if FLAGS.data_set == 'ptb':
    iterator = ptb_loader.ptb_iterator(data, FLAGS.batch_size,
                                       FLAGS.sequence_length,
                                       FLAGS.epoch_size_override)
  elif FLAGS.data_set == 'imdb':
    iterator = imdb_loader.imdb_iterator(data, FLAGS.batch_size,
                                         FLAGS.sequence_length)
  return iterator


def convert_to_human_readable(id_to_word, arr, p, max_num_to_print):
  """Convert a np.array of indices into words using id_to_word dictionary.
  Return max_num_to_print results.
  """
  assert arr.ndim == 2

  samples = []
  for sequence_id in xrange(min(len(arr), max_num_to_print)):
    sample = []
    for i, index in enumerate(arr[sequence_id, :]):
      if p[sequence_id, i] == 1:
        sample.append(str(id_to_word[index]))
      else:
        sample.append('*' + str(id_to_word[index]))
    buffer_str = ' '.join(sample)
    samples.append(buffer_str)
  return samples


def write_unmasked_log(log, id_to_word, sequence_eval):
  """Helper function for logging evaluated sequences without mask."""
  indices_arr = np.asarray(sequence_eval)
  samples = helper.convert_to_human_readable(id_to_word, indices_arr,
                                             FLAGS.batch_size)
  for sample in samples:
    log.write(sample + '\n')
  log.flush()
  return samples


def write_masked_log(log, id_to_word, sequence_eval, present_eval):
  indices_arr = np.asarray(sequence_eval)
  samples = convert_to_human_readable(id_to_word, indices_arr, present_eval,
                                      FLAGS.batch_size)
  for sample in samples:
    log.write(sample + '\n')
  log.flush()
  return samples


def generate_logs(sess, model, log, id_to_word, feed):
  """Impute Sequences using the model for a particular feed and send it to
  logs.
  """
  # Impute Sequences.
  [p, inputs_eval, sequence_eval] = sess.run(
      [model.present, model.inputs, model.fake_sequence], feed_dict=feed)

  # Add the 0th time-step for coherence.
  first_token = np.expand_dims(inputs_eval[:, 0], axis=1)
  sequence_eval = np.concatenate((first_token, sequence_eval), axis=1)

  # 0th token always present.
  p = np.concatenate((np.ones((FLAGS.batch_size, 1)), p), axis=1)

  if FLAGS.output_masked_logs:
    samples = write_masked_log(log, id_to_word, sequence_eval, p)
  else:
    samples = write_unmasked_log(log, id_to_word, sequence_eval)
  return samples


def generate_samples(hparams, data, id_to_word, log_dir, output_file):
  """"Generate samples.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    data: Data to evaluate.
    id_to_word: Dictionary of indices to words.
    log_dir: Log directory.
    output_file: Output file for the samples.
  """
  # Boolean indicating operational mode.
  is_training = False

  # Set a random seed to keep fixed mask.
  np.random.seed(0)

  with tf.Graph().as_default():
    # Construct the model.
    model = train_mask_gan.create_MaskGAN(hparams, is_training)

    ## Retrieve the initial savers.
    init_savers = model_utils.retrieve_init_savers(hparams)

    ## Initial saver function to supervisor.
    init_fn = partial(model_utils.init_fn, init_savers)

    is_chief = FLAGS.task == 0

    # Create the supervisor. It will take care of initialization, summaries,
    # checkpoints, and recovery.
    sv = tf.Supervisor(
        logdir=log_dir,
        is_chief=is_chief,
        saver=model.saver,
        global_step=model.global_step,
        recovery_wait_secs=30,
        summary_op=None,
        init_fn=init_fn)

    # Get an initialized, and possibly recovered session. Launch the
    # services: Checkpointing, Summaries, step counting.
    #
    # When multiple replicas of this program are running the services are
    # only launched by the 'chief' replica.
    with sv.managed_session(
        FLAGS.master, start_standard_services=False) as sess:

      # Generator statefulness over the epoch.
      [gen_initial_state_eval, fake_gen_initial_state_eval] = sess.run(
          [model.eval_initial_state, model.fake_gen_initial_state])

      for n in xrange(FLAGS.number_epochs):
        print('Epoch number: %d' % n)
        # print('Percent done: %.2f' % float(n) / float(FLAGS.number_epochs))
        iterator = get_iterator(data)
        for x, y, _ in iterator:
          if FLAGS.eval_language_model:
            is_present_rate = 0.
          else:
            is_present_rate = FLAGS.is_present_rate
          tf.logging.info(
              'Evaluating on is_present_rate=%.3f.' % is_present_rate)

          model_utils.assign_percent_real(sess, model.percent_real_update,
                                          model.new_rate, is_present_rate)

          # Randomly mask out tokens.
          p = model_utils.generate_mask()

          eval_feed = {model.inputs: x, model.targets: y, model.present: p}

          if FLAGS.data_set == 'ptb':
            # Statefulness for *evaluation* Generator.
            for i, (c, h) in enumerate(model.eval_initial_state):
              eval_feed[c] = gen_initial_state_eval[i].c
              eval_feed[h] = gen_initial_state_eval[i].h

            # Statefulness for the Generator.
            for i, (c, h) in enumerate(model.fake_gen_initial_state):
              eval_feed[c] = fake_gen_initial_state_eval[i].c
              eval_feed[h] = fake_gen_initial_state_eval[i].h

          [gen_initial_state_eval, fake_gen_initial_state_eval, _] = sess.run(
              [
                  model.eval_final_state, model.fake_gen_final_state,
                  model.global_step
              ],
              feed_dict=eval_feed)

          generate_logs(sess, model, output_file, id_to_word, eval_feed)

      output_file.close()
      print('Closing output_file.')
      return


def main(_):
  hparams = train_mask_gan.create_hparams()
  log_dir = FLAGS.base_directory

  tf.gfile.MakeDirs(FLAGS.output_path)
  output_file = tf.gfile.GFile(
      os.path.join(FLAGS.output_path, 'reviews.txt'), mode='w')

  # Load data set.
  if FLAGS.data_set == 'ptb':
    raw_data = ptb_loader.ptb_raw_data(FLAGS.data_dir)
    train_data, valid_data, _, _ = raw_data
  elif FLAGS.data_set == 'imdb':
    raw_data = imdb_loader.imdb_raw_data(FLAGS.data_dir)
    train_data, valid_data = raw_data
  else:
    raise NotImplementedError

  # Generating more data on train set.
  if FLAGS.sample_mode == SAMPLE_TRAIN:
    data_set = train_data
  elif FLAGS.sample_mode == SAMPLE_VALIDATION:
    data_set = valid_data
  else:
    raise NotImplementedError

  # Dictionary and reverse dictionry.
  if FLAGS.data_set == 'ptb':
    word_to_id = ptb_loader.build_vocab(
        os.path.join(FLAGS.data_dir, 'ptb.train.txt'))
  elif FLAGS.data_set == 'imdb':
    word_to_id = imdb_loader.build_vocab(
        os.path.join(FLAGS.data_dir, 'vocab.txt'))
  id_to_word = {v: k for k, v in word_to_id.iteritems()}

  FLAGS.vocab_size = len(id_to_word)
  print('Vocab size: %d' % FLAGS.vocab_size)

  generate_samples(hparams, data_set, id_to_word, log_dir, output_file)


if __name__ == '__main__':
  tf.app.run()
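To illustrate the masked display convention used by `convert_to_human_readable` above (generated tokens at positions where the presence mask `p` is 0 are prefixed with `*`), here is a small self-contained sketch with a toy vocabulary; the ids and words are invented for illustration only:

```python
import numpy as np

id_to_word = {0: 'the', 1: 'cat', 2: 'sat', 3: '<eos>'}
arr = np.array([[0, 1, 2, 3]])   # one decoded sequence of token ids
p = np.array([[1, 0, 0, 1]])     # 1 = token was given, 0 = token was imputed

sample = []
for i, index in enumerate(arr[0, :]):
  if p[0, i] == 1:
    sample.append(str(id_to_word[index]))
  else:
    sample.append('*' + str(id_to_word[index]))
print(' '.join(sample))  # -> "the *cat *sat <eos>"
```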
research/maskgan/losses/__init__.py (new file, empty)
research/maskgan/losses/losses.py (new file, mode 100644)

# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Losses for Generator and Discriminator."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


def discriminator_loss(predictions, labels, missing_tokens):
  """Discriminator loss based on predictions and labels.

  Args:
    predictions: Discriminator linear predictions Tensor of shape [batch_size,
      sequence_length]
    labels: Labels for predictions, Tensor of shape [batch_size,
      sequence_length]
    missing_tokens: Indicator for the missing tokens. Evaluate the loss only
      on the tokens that were missing.

  Returns:
    loss: Scalar tf.float32 loss.
  """
  loss = tf.losses.sigmoid_cross_entropy(
      labels, predictions, weights=missing_tokens)
  loss = tf.Print(
      loss, [loss, labels, missing_tokens],
      message='loss, labels, missing_tokens',
      summarize=25,
      first_n=25)
  return loss


def cross_entropy_loss_matrix(gen_labels, gen_logits):
  """Computes the cross entropy loss for G.

  Args:
    gen_labels: Labels for the correct token.
    gen_logits: Generator logits.

  Returns:
    loss_matrix: Loss matrix of shape [batch_size, sequence_length].
  """
  cross_entropy_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=gen_labels, logits=gen_logits)
  return cross_entropy_loss


def GAN_loss_matrix(dis_predictions):
  """Computes the cross entropy loss for G.

  Args:
    dis_predictions: Discriminator predictions.

  Returns:
    loss_matrix: Loss matrix of shape [batch_size, sequence_length].
  """
  eps = tf.constant(1e-7, tf.float32)
  gan_loss_matrix = -tf.log(dis_predictions + eps)
  return gan_loss_matrix


def generator_GAN_loss(predictions):
  """Generator GAN loss based on Discriminator predictions."""
  return -tf.log(tf.reduce_mean(predictions))


def generator_blended_forward_loss(gen_logits, gen_labels, dis_predictions,
                                   is_real_input):
  """Computes the masked-loss for G. This will be a blend of cross-entropy
  loss where the true label is known and GAN loss where the true label has been
  masked.

  Args:
    gen_logits: Generator logits.
    gen_labels: Labels for the correct token.
    dis_predictions: Discriminator predictions.
    is_real_input: Tensor indicating whether the label is present.

  Returns:
    loss: Scalar tf.float32 total loss.
  """
  cross_entropy_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=gen_labels, logits=gen_logits)
  gan_loss = -tf.log(dis_predictions)
  loss_matrix = tf.where(is_real_input, cross_entropy_loss, gan_loss)
  return tf.reduce_mean(loss_matrix)


def wasserstein_generator_loss(gen_logits, gen_labels, dis_values,
                               is_real_input):
  """Computes the masked-loss for G. This will be a blend of cross-entropy
  loss where the true label is known and GAN loss where the true label is
  missing.

  Args:
    gen_logits: Generator logits.
    gen_labels: Labels for the correct token.
    dis_values: Discriminator values Tensor of shape [batch_size,
      sequence_length].
    is_real_input: Tensor indicating whether the label is present.

  Returns:
    loss: Scalar tf.float32 total loss.
  """
  cross_entropy_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=gen_labels, logits=gen_logits)
  # Maximize the dis_values (minimize the negative)
  gan_loss = -dis_values
  loss_matrix = tf.where(is_real_input, cross_entropy_loss, gan_loss)
  loss = tf.reduce_mean(loss_matrix)
  return loss


def wasserstein_discriminator_loss(real_values, fake_values):
  """Wasserstein discriminator loss.

  Args:
    real_values: Value given by the Wasserstein Discriminator to real data.
    fake_values: Value given by the Wasserstein Discriminator to fake data.

  Returns:
    loss: Scalar tf.float32 loss.
  """
  real_avg = tf.reduce_mean(real_values)
  fake_avg = tf.reduce_mean(fake_values)

  wasserstein_loss = real_avg - fake_avg
  return wasserstein_loss


def wasserstein_discriminator_loss_intrabatch(values, is_real_input):
  """Wasserstein discriminator loss. This is an odd variant where the value
  difference is between the real tokens and the fake tokens within a single
  batch.

  Args:
    values: Value given by the Wasserstein Discriminator of shape [batch_size,
      sequence_length] to an imputed batch (real and fake).
    is_real_input: tf.bool Tensor of shape [batch_size, sequence_length]. If
      true, it indicates that the label is known.

  Returns:
    wasserstein_loss: Scalar tf.float32 loss.
  """
  zero_tensor = tf.constant(0., dtype=tf.float32, shape=[])

  present = tf.cast(is_real_input, tf.float32)
  missing = tf.cast(1 - present, tf.float32)

  # Counts for real and fake tokens.
  real_count = tf.reduce_sum(present)
  fake_count = tf.reduce_sum(missing)

  # Averages for real and fake token values.
  real = tf.mul(values, present)
  fake = tf.mul(values, missing)
  real_avg = tf.reduce_sum(real) / real_count
  fake_avg = tf.reduce_sum(fake) / fake_count

  # If there are no real or fake entries in the batch, we assign an average
  # value of zero.
  real_avg = tf.where(tf.equal(real_count, 0), zero_tensor, real_avg)
  fake_avg = tf.where(tf.equal(fake_count, 0), zero_tensor, fake_avg)

  wasserstein_loss = real_avg - fake_avg
  return wasserstein_loss
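A small numeric sketch of the blending idea in `generator_blended_forward_loss` above: where the true token is present the generator receives a cross-entropy signal, and where it was masked it receives `-log D` from the discriminator. The numbers below are invented for illustration and use NumPy in place of the TF graph ops:

```python
import numpy as np

# Per-token losses for a single sequence of length 4 (illustrative values).
cross_entropy_loss = np.array([0.2, 1.5, 0.9, 0.3])
dis_predictions = np.array([0.9, 0.4, 0.6, 0.8])  # D's probability of "real"
gan_loss = -np.log(dis_predictions)

is_real_input = np.array([True, False, False, True])

# Element-wise selection, mirroring tf.where(is_real_input, ce, gan_loss).
loss_matrix = np.where(is_real_input, cross_entropy_loss, gan_loss)
print(loss_matrix, loss_matrix.mean())
```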
research/maskgan/model_utils/__init__.py (new file, empty)
research/maskgan/model_utils/helper.py (new file, mode 100644)

# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Random helper functions for converting between indices and one-hot encodings
as well as printing/logging helpers.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf


def variable_summaries(var, name):
  """Attach a lot of summaries to a Tensor."""
  mean = tf.reduce_mean(var)
  tf.summary.scalar('mean/' + name, mean)
  with tf.name_scope('stddev'):
    stddev = tf.sqrt(tf.reduce_sum(tf.square(var - mean)))
  tf.summary.scalar('sttdev/' + name, stddev)
  tf.summary.scalar('max/' + name, tf.reduce_max(var))
  tf.summary.scalar('min/' + name, tf.reduce_min(var))
  tf.summary.histogram(name, var)


def zip_seq_pred_crossent(id_to_word, sequences, predictions, cross_entropy):
  """Zip together the sequences, predictions, cross entropy."""
  indices = convert_to_indices(sequences)

  batch_of_metrics = []

  for ind_batch, pred_batch, crossent_batch in zip(indices, predictions,
                                                   cross_entropy):
    metrics = []

    for index, pred, crossent in zip(ind_batch, pred_batch, crossent_batch):
      metrics.append([str(id_to_word[index]), pred, crossent])

    batch_of_metrics.append(metrics)
  return batch_of_metrics


def print_and_log(log, id_to_word, sequence_eval, max_num_to_print=5):
  """Helper function for printing and logging evaluated sequences."""
  indices_eval = convert_to_indices(sequence_eval)
  indices_arr = np.asarray(indices_eval)
  samples = convert_to_human_readable(id_to_word, indices_arr,
                                      max_num_to_print)

  for i, sample in enumerate(samples):
    print('Sample', i, '. ', sample)
    log.write('\nSample ' + str(i) + '. ' + sample)
  log.write('\n')
  print('\n')
  log.flush()


def convert_to_human_readable(id_to_word, arr, max_num_to_print):
  """Convert a np.array of indices into words using id_to_word dictionary.
  Return max_num_to_print results.
  """
  assert arr.ndim == 2

  samples = []
  for sequence_id in xrange(min(len(arr), max_num_to_print)):
    buffer_str = ' '.join(
        [str(id_to_word[index]) for index in arr[sequence_id, :]])
    samples.append(buffer_str)

  return samples


def index_to_vocab_array(indices, vocab_size, sequence_length):
  """Convert the indices into an array with vocab_size one-hot encoding."""

  # Extract properties of the indices.
  num_batches = len(indices)
  shape = list(indices.shape)
  shape.append(vocab_size)

  # Construct the vocab_size array.
  new_arr = np.zeros(shape)

  for n in xrange(num_batches):
    indices_batch = indices[n]
    new_arr_batch = new_arr[n]

    # We map all indices greater than the vocabulary size to an unknown
    # character.
    indices_batch = np.where(indices_batch < vocab_size, indices_batch,
                             vocab_size - 1)
    # Convert indices to vocab_size dimensions.
    new_arr_batch[np.arange(sequence_length), indices_batch] = 1
  return new_arr


def convert_to_indices(sequences):
  """Convert a list of size [batch_size, sequence_length, vocab_size] to
  a list of size [batch_size, sequence_length] where the vocab element is
  denoted by the index.
  """
  batch_of_indices = []

  for sequence in sequences:
    indices = []
    for embedding in sequence:
      indices.append(np.argmax(embedding))
    batch_of_indices.append(indices)
  return batch_of_indices


def convert_and_zip(id_to_word, sequences, predictions):
  """Helper function for printing or logging. Retrieves list of sequences
  and predictions and zips them together.
  """
  indices = convert_to_indices(sequences)

  batch_of_indices_predictions = []

  for index_batch, pred_batch in zip(indices, predictions):
    indices_predictions = []

    for index, pred in zip(index_batch, pred_batch):
      indices_predictions.append([str(id_to_word[index]), pred])
    batch_of_indices_predictions.append(indices_predictions)
  return batch_of_indices_predictions


def recursive_length(item):
  """Recursively determine the total number of elements in nested list."""
  if type(item) == list:
    return sum(recursive_length(subitem) for subitem in item)
  else:
    return 1.


def percent_correct(real_sequence, fake_sequences):
  """Determine the percent of tokens correctly generated within a batch."""
  identical = 0.
  for fake_sequence in fake_sequences:
    for real, fake in zip(real_sequence, fake_sequence):
      if real == fake:
        identical += 1.
  return identical / recursive_length(fake_sequences)
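A quick sketch of `percent_correct` and `recursive_length` on toy data; the token ids below are invented for illustration:

```python
from model_utils import helper

real_sequence = [3, 7, 7, 2]
fake_sequences = [[3, 7, 1, 2],   # 3 of 4 tokens match
                  [3, 0, 7, 2]]   # 3 of 4 tokens match

# 6 matching tokens out of recursive_length(fake_sequences) == 8 elements.
print(helper.percent_correct(real_sequence, fake_sequences))  # -> 0.75
```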
research/maskgan/model_utils/model_construction.py (new file, mode 100644)

# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Model construction."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports
import tensorflow as tf

from models import bidirectional
from models import bidirectional_vd
from models import bidirectional_zaremba
from models import cnn
from models import critic_vd
from models import feedforward
from models import rnn
from models import rnn_nas
from models import rnn_vd
from models import rnn_zaremba
from models import seq2seq
from models import seq2seq_nas
from models import seq2seq_vd
from models import seq2seq_zaremba

FLAGS = tf.app.flags.FLAGS


# TODO(adai): IMDB labels placeholder to model.
def create_generator(hparams,
                     inputs,
                     targets,
                     present,
                     is_training,
                     is_validating,
                     reuse=None):
  """Create the Generator model specified by the FLAGS and hparams.

  Args;
    hparams: Hyperparameters for the MaskGAN.
    inputs: tf.int32 Tensor of the sequence input of shape [batch_size,
      sequence_length].
    present: tf.bool Tensor indicating the presence or absence of the token
      of shape [batch_size, sequence_length].
    is_training: Whether the model is training.
    is_validating: Whether the model is being run in validation mode for
      calculating the perplexity.
    reuse (Optional): Whether to reuse the model.

  Returns:
    Tuple of the (sequence, logits, log_probs) of the Generator. Sequence
      and logits have shape [batch_size, sequence_length, vocab_size]. The
      log_probs will have shape [batch_size, sequence_length]. Log_probs
      corresponds to the log probability of selecting the words.
  """
  if FLAGS.generator_model == 'rnn':
    (sequence, logits, log_probs, initial_state, final_state) = rnn.generator(
        hparams,
        inputs,
        targets,
        present,
        is_training=is_training,
        is_validating=is_validating,
        reuse=reuse)
  elif FLAGS.generator_model == 'rnn_zaremba':
    (sequence, logits, log_probs, initial_state,
     final_state) = rnn_zaremba.generator(
         hparams,
         inputs,
         targets,
         present,
         is_training=is_training,
         is_validating=is_validating,
         reuse=reuse)
  elif FLAGS.generator_model == 'seq2seq':
    (sequence, logits, log_probs, initial_state,
     final_state) = seq2seq.generator(
         hparams,
         inputs,
         targets,
         present,
         is_training=is_training,
         is_validating=is_validating,
         reuse=reuse)
  elif FLAGS.generator_model == 'seq2seq_zaremba':
    (sequence, logits, log_probs, initial_state,
     final_state) = seq2seq_zaremba.generator(
         hparams,
         inputs,
         targets,
         present,
         is_training=is_training,
         is_validating=is_validating,
         reuse=reuse)
  elif FLAGS.generator_model == 'rnn_nas':
    (sequence, logits, log_probs, initial_state,
     final_state) = rnn_nas.generator(
         hparams,
         inputs,
         targets,
         present,
         is_training=is_training,
         is_validating=is_validating,
         reuse=reuse)
  elif FLAGS.generator_model == 'seq2seq_nas':
    (sequence, logits, log_probs, initial_state,
     final_state) = seq2seq_nas.generator(
         hparams,
         inputs,
         targets,
         present,
         is_training=is_training,
         is_validating=is_validating,
         reuse=reuse)
  elif FLAGS.generator_model == 'seq2seq_vd':
    (sequence, logits, log_probs, initial_state, final_state,
     encoder_states) = seq2seq_vd.generator(
         hparams,
         inputs,
         targets,
         present,
         is_training=is_training,
         is_validating=is_validating,
         reuse=reuse)
  else:
    raise NotImplementedError
  return (sequence, logits, log_probs, initial_state, final_state,
          encoder_states)


def create_discriminator(hparams,
                         sequence,
                         is_training,
                         reuse=None,
                         initial_state=None,
                         inputs=None,
                         present=None):
  """Create the Discriminator model specified by the FLAGS and hparams.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    sequence: tf.int32 Tensor sequence of shape [batch_size, sequence_length]
    is_training: Whether the model is training.
    reuse (Optional): Whether to reuse the model.

  Returns:
    predictions: tf.float32 Tensor of predictions of shape [batch_size,
      sequence_length]
  """
  if FLAGS.discriminator_model == 'cnn':
    predictions = cnn.discriminator(
        hparams, sequence, is_training=is_training, reuse=reuse)
  elif FLAGS.discriminator_model == 'fnn':
    predictions = feedforward.discriminator(
        hparams, sequence, is_training=is_training, reuse=reuse)
  elif FLAGS.discriminator_model == 'rnn':
    predictions = rnn.discriminator(
        hparams, sequence, is_training=is_training, reuse=reuse)
  elif FLAGS.discriminator_model == 'bidirectional':
    predictions = bidirectional.discriminator(
        hparams, sequence, is_training=is_training, reuse=reuse)
  elif FLAGS.discriminator_model == 'bidirectional_zaremba':
    predictions = bidirectional_zaremba.discriminator(
        hparams, sequence, is_training=is_training, reuse=reuse)
  elif FLAGS.discriminator_model == 'seq2seq_vd':
    predictions = seq2seq_vd.discriminator(
        hparams,
        inputs,
        present,
        sequence,
        is_training=is_training,
        reuse=reuse)
  elif FLAGS.discriminator_model == 'rnn_zaremba':
    predictions = rnn_zaremba.discriminator(
        hparams, sequence, is_training=is_training, reuse=reuse)
  elif FLAGS.discriminator_model == 'rnn_nas':
    predictions = rnn_nas.discriminator(
        hparams, sequence, is_training=is_training, reuse=reuse)
  elif FLAGS.discriminator_model == 'rnn_vd':
    predictions = rnn_vd.discriminator(
        hparams,
        sequence,
        is_training=is_training,
        reuse=reuse,
        initial_state=initial_state)
  elif FLAGS.discriminator_model == 'bidirectional_vd':
    predictions = bidirectional_vd.discriminator(
        hparams,
        sequence,
        is_training=is_training,
        reuse=reuse,
        initial_state=initial_state)
  else:
    raise NotImplementedError
  return predictions


def create_critic(hparams, sequence, is_training, reuse=None):
  """Create the Critic model specified by the FLAGS and hparams.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    sequence: tf.int32 Tensor sequence of shape [batch_size, sequence_length]
    is_training: Whether the model is training.
    reuse (Optional): Whether to reuse the model.

  Returns:
    values: tf.float32 Tensor of predictions of shape [batch_size,
      sequence_length]
  """
  if FLAGS.baseline_method == 'critic':
    if FLAGS.discriminator_model == 'seq2seq_vd':
      values = critic_vd.critic_seq2seq_vd_derivative(
          hparams, sequence, is_training, reuse=reuse)
    else:
      raise NotImplementedError
  else:
    raise NotImplementedError

  return values
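The three factories above all follow the same dispatch-on-flags pattern: the sub-module that builds the network is selected entirely by `FLAGS.generator_model`, `FLAGS.discriminator_model`, and `FLAGS.baseline_method`. For readers unfamiliar with this TF 1.x style of flag-based configuration, here is a stripped-down sketch of the pattern with a hypothetical registry in place of the real `models` package:

```python
# Hypothetical, minimal illustration of the flag-dispatch pattern used above;
# the real code selects among modules such as rnn, seq2seq_vd, etc.
_GENERATORS = {
    'rnn': lambda hparams, inputs: ('rnn outputs for', inputs),
    'seq2seq_vd': lambda hparams, inputs: ('seq2seq_vd outputs for', inputs),
}


def create_generator_sketch(generator_model, hparams, inputs):
  # Unknown model names fail loudly, mirroring the NotImplementedError above.
  if generator_model not in _GENERATORS:
    raise NotImplementedError(generator_model)
  return _GENERATORS[generator_model](hparams, inputs)


print(create_generator_sketch('seq2seq_vd', hparams=None, inputs=[1, 2, 3]))
```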
research/maskgan/model_utils/model_losses.py (new file, mode 100644)

# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Model loss construction."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports
import numpy as np
import tensorflow as tf

# Useful for REINFORCE baseline.
from losses import losses

FLAGS = tf.app.flags.FLAGS


def create_dis_loss(fake_predictions, real_predictions, targets_present):
  """Compute Discriminator loss across real/fake."""
  missing = tf.cast(targets_present, tf.int32)
  missing = 1 - missing
  missing = tf.cast(missing, tf.bool)

  real_labels = tf.ones([FLAGS.batch_size, FLAGS.sequence_length])
  dis_loss_real = tf.losses.sigmoid_cross_entropy(
      real_labels, real_predictions, weights=missing)
  dis_loss_fake = tf.losses.sigmoid_cross_entropy(
      targets_present, fake_predictions, weights=missing)

  dis_loss = (dis_loss_fake + dis_loss_real) / 2.
  return dis_loss, dis_loss_fake, dis_loss_real


def create_critic_loss(cumulative_rewards, estimated_values, present):
  """Compute Critic loss in estimating the value function. This should be an
  estimate only for the missing elements."""
  missing = tf.cast(present, tf.int32)
  missing = 1 - missing
  missing = tf.cast(missing, tf.bool)

  loss = tf.losses.mean_squared_error(
      labels=cumulative_rewards, predictions=estimated_values, weights=missing)
  return loss


def create_masked_cross_entropy_loss(targets, present, logits):
  """Calculate the cross entropy loss matrices for the masked tokens."""
  cross_entropy_losses = losses.cross_entropy_loss_matrix(targets, logits)

  # Zeros matrix.
  zeros_losses = tf.zeros(
      shape=[FLAGS.batch_size, FLAGS.sequence_length], dtype=tf.float32)

  missing_ce_loss = tf.where(present, zeros_losses, cross_entropy_losses)

  return missing_ce_loss


def calculate_reinforce_objective(hparams,
                                  log_probs,
                                  dis_predictions,
                                  present,
                                  estimated_values=None):
  """Calculate the REINFORCE objectives. The REINFORCE objective should
  only be on the tokens that were missing. Specifically, the final Generator
  reward should be based on the Discriminator predictions on missing tokens.
  The log probaibilities should be only for missing tokens and the baseline
  should be calculated only on the missing tokens.

  For this model, we optimize the reward is the log of the *conditional*
  probability the Discriminator assigns to the distribution. Specifically, for
  a Discriminator D which outputs probability of real, given the past context,

    r_t = log D(x_t|x_0,x_1,...x_{t-1})

  And the policy for Generator G is the log-probability of taking action x2
  given the past context.

  Args:
    hparams: MaskGAN hyperparameters.
    log_probs: tf.float32 Tensor of log probailities of the tokens selected by
      the Generator. Shape [batch_size, sequence_length].
    dis_predictions: tf.float32 Tensor of the predictions from the
      Discriminator. Shape [batch_size, sequence_length].
    present: tf.bool Tensor indicating which tokens are present. Shape
      [batch_size, sequence_length].
    estimated_values: tf.float32 Tensor of estimated state values of tokens.
      Shape [batch_size, sequence_length]

  Returns:
    final_gen_objective: Final REINFORCE objective for the sequence.
    rewards: tf.float32 Tensor of rewards for sequence of shape [batch_size,
      sequence_length]
    advantages: tf.float32 Tensor of advantages for sequence of shape
      [batch_size, sequence_length]
    baselines: tf.float32 Tensor of baselines for sequence of shape
      [batch_size, sequence_length]
    maintain_averages_op: ExponentialMovingAverage apply average op to
      maintain the baseline.
  """
  # Final Generator objective.
  final_gen_objective = 0.
  gamma = hparams.rl_discount_rate
  eps = 1e-7

  # Generator rewards are log-probabilities.
  eps = tf.constant(1e-7, tf.float32)
  dis_predictions = tf.nn.sigmoid(dis_predictions)
  rewards = tf.log(dis_predictions + eps)

  # Apply only for missing elements.
  zeros = tf.zeros_like(present, dtype=tf.float32)
  log_probs = tf.where(present, zeros, log_probs)
  rewards = tf.where(present, zeros, rewards)

  # Unstack Tensors into lists.
  rewards_list = tf.unstack(rewards, axis=1)
  log_probs_list = tf.unstack(log_probs, axis=1)
  missing = 1. - tf.cast(present, tf.float32)
  missing_list = tf.unstack(missing, axis=1)

  # Cumulative Discounted Returns. The true value function V*(s).
  cumulative_rewards = []
  for t in xrange(FLAGS.sequence_length):
    cum_value = tf.zeros(shape=[FLAGS.batch_size])
    for s in xrange(t, FLAGS.sequence_length):
      cum_value += missing_list[s] * np.power(gamma, (s - t)) * rewards_list[s]
    cumulative_rewards.append(cum_value)
  cumulative_rewards = tf.stack(cumulative_rewards, axis=1)

  ## REINFORCE with different baselines.
  # We create a separate critic functionality for the Discriminator. This
  # will need to operate unidirectionally and it may take in the past context.
  if FLAGS.baseline_method == 'critic':

    # Critic loss calculated from the estimated value function \hat{V}(s)
    # versus the true value function V*(s).
    critic_loss = create_critic_loss(cumulative_rewards, estimated_values,
                                     present)

    # Baselines are coming from the critic's estimated state values.
    baselines = tf.unstack(estimated_values, axis=1)

    ## Calculate the Advantages, A(s,a) = Q(s,a) - \hat{V}(s).
    advantages = []
    for t in xrange(FLAGS.sequence_length):
      log_probability = log_probs_list[t]
      cum_advantage = tf.zeros(shape=[FLAGS.batch_size])

      for s in xrange(t, FLAGS.sequence_length):
        cum_advantage += missing_list[s] * np.power(gamma,
                                                    (s - t)) * rewards_list[s]
      cum_advantage -= baselines[t]
      # Clip advantages.
      cum_advantage = tf.clip_by_value(cum_advantage, -FLAGS.advantage_clipping,
                                       FLAGS.advantage_clipping)
      advantages.append(missing_list[t] * cum_advantage)
      final_gen_objective += tf.multiply(
          log_probability, missing_list[t] * tf.stop_gradient(cum_advantage))

    maintain_averages_op = None
    baselines = tf.stack(baselines, axis=1)
    advantages = tf.stack(advantages, axis=1)

  # Split the batch into half. Use half for MC estimates for REINFORCE.
  # Use the other half to establish a baseline.
  elif FLAGS.baseline_method == 'dis_batch':
    # TODO(liamfedus): Recheck.
    [rewards_half, baseline_half] = tf.split(
        rewards, num_or_size_splits=2, axis=0)
    [log_probs_half, _] = tf.split(log_probs, num_or_size_splits=2, axis=0)
    [reward_present_half, baseline_present_half] = tf.split(
        present, num_or_size_splits=2, axis=0)

    # Unstack to lists.
    baseline_list = tf.unstack(baseline_half, axis=1)
    baseline_missing = 1. - tf.cast(baseline_present_half, tf.float32)
    baseline_missing_list = tf.unstack(baseline_missing, axis=1)

    baselines = []
    for t in xrange(FLAGS.sequence_length):
      # Calculate baseline only for missing tokens.
      num_missing = tf.reduce_sum(baseline_missing_list[t])
      avg_baseline = tf.reduce_sum(
          baseline_missing_list[t] * baseline_list[t],
          keep_dims=True) / (num_missing + eps)
      baseline = tf.tile(avg_baseline, multiples=[FLAGS.batch_size / 2])
      baselines.append(baseline)

    # Unstack to lists.
    rewards_list = tf.unstack(rewards_half, axis=1)
    log_probs_list = tf.unstack(log_probs_half, axis=1)
    reward_missing = 1. - tf.cast(reward_present_half, tf.float32)
    reward_missing_list = tf.unstack(reward_missing, axis=1)

    ## Calculate the Advantages, A(s,a) = Q(s,a) - \hat{V}(s).
    advantages = []
    for t in xrange(FLAGS.sequence_length):
      log_probability = log_probs_list[t]
      cum_advantage = tf.zeros(shape=[FLAGS.batch_size / 2])

      for s in xrange(t, FLAGS.sequence_length):
        cum_advantage += reward_missing_list[s] * np.power(gamma, (s - t)) * (
            rewards_list[s] - baselines[s])
      # Clip advantages.
      cum_advantage = tf.clip_by_value(cum_advantage, -FLAGS.advantage_clipping,
                                       FLAGS.advantage_clipping)
      advantages.append(reward_missing_list[t] * cum_advantage)
      final_gen_objective += tf.multiply(
          log_probability,
          reward_missing_list[t] * tf.stop_gradient(cum_advantage))

    # Cumulative Discounted Returns. The true value function V*(s).
    cumulative_rewards = []
    for t in xrange(FLAGS.sequence_length):
      cum_value = tf.zeros(shape=[FLAGS.batch_size / 2])
      for s in xrange(t, FLAGS.sequence_length):
        cum_value += reward_missing_list[s] * np.power(gamma,
                                                       (s - t)) * rewards_list[s]
      cumulative_rewards.append(cum_value)
    cumulative_rewards = tf.stack(cumulative_rewards, axis=1)

    rewards = rewards_half
    critic_loss = None
    maintain_averages_op = None
    baselines = tf.stack(baselines, axis=1)
    advantages = tf.stack(advantages, axis=1)

  # Exponential Moving Average baseline.
  elif FLAGS.baseline_method == 'ema':
    # TODO(liamfedus): Recheck.
    # Lists of rewards and Log probabilities of the actions taken only for
    # missing tokens.
    ema = tf.train.ExponentialMovingAverage(decay=hparams.baseline_decay)
    maintain_averages_op = ema.apply(rewards_list)

    baselines = []
    for r in rewards_list:
      baselines.append(ema.average(r))

    ## Calculate the Advantages, A(s,a) = Q(s,a) - \hat{V}(s).
    advantages = []
    for t in xrange(FLAGS.sequence_length):
      log_probability = log_probs_list[t]

      # Calculate the forward advantage only on the missing tokens.
      cum_advantage = tf.zeros(shape=[FLAGS.batch_size])
      for s in xrange(t, FLAGS.sequence_length):
        cum_advantage += missing_list[s] * np.power(gamma, (s - t)) * (
            rewards_list[s] - baselines[s])
      # Clip advantages.
      cum_advantage = tf.clip_by_value(cum_advantage, -FLAGS.advantage_clipping,
                                       FLAGS.advantage_clipping)
      advantages.append(missing_list[t] * cum_advantage)
      final_gen_objective += tf.multiply(
          log_probability, missing_list[t] * tf.stop_gradient(cum_advantage))

    critic_loss = None
    baselines = tf.stack(baselines, axis=1)
    advantages = tf.stack(advantages, axis=1)

  elif FLAGS.baseline_method is None:
    num_missing = tf.reduce_sum(missing)
    final_gen_objective += tf.reduce_sum(rewards) / (num_missing + eps)
    baselines = tf.zeros_like(rewards)
    critic_loss = None
    maintain_averages_op = None
    advantages = cumulative_rewards

  else:
    raise NotImplementedError

  return [
      final_gen_objective, log_probs, rewards, advantages, baselines,
      maintain_averages_op, critic_loss, cumulative_rewards
  ]


def calculate_log_perplexity(logits, targets, present):
  """Calculate the average log perplexity per *missing* token.

  Args:
    logits: tf.float32 Tensor of the logits of shape [batch_size,
      sequence_length, vocab_size].
    targets: tf.int32 Tensor of the sequence target of shape [batch_size,
      sequence_length].
    present: tf.bool Tensor indicating the presence or absence of the token
      of shape [batch_size, sequence_length].

  Returns:
    avg_log_perplexity: Scalar indicating the average log perplexity per
      missing token in the batch.
  """
  # logits = tf.Print(logits, [logits], message='logits:', summarize=50)
  # targets = tf.Print(targets, [targets], message='targets:', summarize=50)
  eps = 1e-12
  logits = tf.reshape(logits, [-1, FLAGS.vocab_size])

  # Only calculate log-perplexity on missing tokens.
  weights = tf.cast(present, tf.float32)
  weights = 1. - weights
  weights = tf.reshape(weights, [-1])
  num_missing = tf.reduce_sum(weights)

  log_perplexity = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
      [logits], [tf.reshape(targets, [-1])], [weights])

  avg_log_perplexity = tf.reduce_sum(log_perplexity) / (num_missing + eps)
  return avg_log_perplexity
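The nested loops in `calculate_reinforce_objective` above compute, for every position t, the discounted return over the remaining missing positions, V*(s_t) = sum over s >= t of gamma^(s-t) * m_s * r_s, where m_s is the missing indicator and r_s = log D at step s; the advantage then subtracts a baseline from that return. A NumPy sketch of the discounting on a toy length-4 sequence, with values invented for illustration:

```python
import numpy as np

gamma = 0.89
rewards = np.array([-0.1, -0.7, -0.4, -0.2])   # log D at each step
missing = np.array([0., 1., 1., 0.])           # only masked steps earn reward

sequence_length = len(rewards)
cumulative_rewards = np.zeros(sequence_length)
for t in range(sequence_length):
  for s in range(t, sequence_length):
    # Same inner sum as the TF code, on plain arrays.
    cumulative_rewards[t] += missing[s] * gamma**(s - t) * rewards[s]
print(cumulative_rewards)
```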
research/maskgan/model_utils/model_optimization.py
0 → 100644
View file @
7d16fc45
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model optimization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS


def create_dis_pretrain_op(hparams, dis_loss, global_step):
  """Create a train op for pretraining the Discriminator."""
  with tf.name_scope('pretrain_generator'):
    optimizer = tf.train.AdamOptimizer(hparams.dis_pretrain_learning_rate)
    dis_vars = [v for v in tf.trainable_variables() if v.op.name.startswith('dis')]
    if FLAGS.dis_update_share_embedding and FLAGS.dis_share_embedding:
      shared_embedding = [
          v for v in tf.trainable_variables()
          if v.op.name == 'gen/decoder/rnn/embedding'
      ][0]
      dis_vars.append(shared_embedding)
    dis_grads = tf.gradients(dis_loss, dis_vars)
    dis_grads_clipped, _ = tf.clip_by_global_norm(dis_grads, FLAGS.grad_clipping)
    dis_pretrain_op = optimizer.apply_gradients(
        zip(dis_grads_clipped, dis_vars), global_step=global_step)
    return dis_pretrain_op


def create_gen_pretrain_op(hparams, cross_entropy_loss, global_step):
  """Create a train op for pretraining the Generator."""
  with tf.name_scope('pretrain_generator'):
    optimizer = tf.train.AdamOptimizer(hparams.gen_pretrain_learning_rate)
    gen_vars = [v for v in tf.trainable_variables() if v.op.name.startswith('gen')]
    gen_grads = tf.gradients(cross_entropy_loss, gen_vars)
    gen_grads_clipped, _ = tf.clip_by_global_norm(gen_grads, FLAGS.grad_clipping)
    gen_pretrain_op = optimizer.apply_gradients(
        zip(gen_grads_clipped, gen_vars), global_step=global_step)
    return gen_pretrain_op


def create_gen_train_op(hparams, learning_rate, gen_loss, global_step, mode):
  """Create Generator train op."""
  del hparams
  with tf.name_scope('train_generator'):
    if FLAGS.generator_optimizer == 'sgd':
      gen_optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    elif FLAGS.generator_optimizer == 'adam':
      gen_optimizer = tf.train.AdamOptimizer(learning_rate)
    else:
      raise NotImplementedError
    gen_vars = [v for v in tf.trainable_variables() if v.op.name.startswith('gen')]
    print('Optimizing Generator vars.')
    for v in gen_vars:
      print(v)
    if mode == 'MINIMIZE':
      gen_grads = tf.gradients(gen_loss, gen_vars)
    elif mode == 'MAXIMIZE':
      gen_grads = tf.gradients(-gen_loss, gen_vars)
    else:
      raise ValueError("Must be one of 'MINIMIZE' or 'MAXIMIZE'")
    gen_grads_clipped, _ = tf.clip_by_global_norm(gen_grads, FLAGS.grad_clipping)
    gen_train_op = gen_optimizer.apply_gradients(
        zip(gen_grads_clipped, gen_vars), global_step=global_step)
    return gen_train_op, gen_grads_clipped, gen_vars


def create_reinforce_gen_train_op(hparams, learning_rate, final_gen_reward,
                                  averages_op, global_step):
  """Create the Generator train_op when using REINFORCE.

  Args:
    hparams: MaskGAN hyperparameters.
    learning_rate: tf.Variable scalar learning rate.
    final_gen_reward: Scalar final REINFORCE reward for the sequence.
    averages_op: ExponentialMovingAverage apply average op to
      maintain the baseline.
    global_step: global_step tf.Variable.

  Returns:
    gen_train_op: Generator training op.
  """
  del hparams
  with tf.name_scope('train_generator'):
    if FLAGS.generator_optimizer == 'sgd':
      gen_optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    elif FLAGS.generator_optimizer == 'adam':
      gen_optimizer = tf.train.AdamOptimizer(learning_rate)
    else:
      raise NotImplementedError
    gen_vars = [v for v in tf.trainable_variables() if v.op.name.startswith('gen')]
    print('\nOptimizing Generator vars:')
    for v in gen_vars:
      print(v)

    # Maximize reward.
    gen_grads = tf.gradients(-final_gen_reward, gen_vars)
    gen_grads_clipped, _ = tf.clip_by_global_norm(gen_grads, FLAGS.grad_clipping)
    maximize_op = gen_optimizer.apply_gradients(
        zip(gen_grads_clipped, gen_vars), global_step=global_step)

    # Group maintain averages op.
    if averages_op:
      gen_train_op = tf.group(maximize_op, averages_op)
    else:
      gen_train_op = maximize_op

    return [gen_train_op, gen_grads, gen_vars]


def create_dis_train_op(hparams, dis_loss, global_step):
  """Create Discriminator train op."""
  with tf.name_scope('train_discriminator'):
    dis_optimizer = tf.train.AdamOptimizer(hparams.dis_learning_rate)
    dis_vars = [v for v in tf.trainable_variables() if v.op.name.startswith('dis')]
    if FLAGS.dis_update_share_embedding and FLAGS.dis_share_embedding:
      shared_embedding = [
          v for v in tf.trainable_variables()
          if v.op.name == 'gen/decoder/rnn/embedding'
      ][0]
      dis_vars.append(shared_embedding)
    print('\nOptimizing Discriminator vars:')
    for v in dis_vars:
      print(v)
    dis_grads = tf.gradients(dis_loss, dis_vars)
    dis_grads_clipped, _ = tf.clip_by_global_norm(dis_grads, FLAGS.grad_clipping)
    dis_train_op = dis_optimizer.apply_gradients(
        zip(dis_grads_clipped, dis_vars), global_step=global_step)
    return dis_train_op, dis_grads_clipped, dis_vars


def create_critic_train_op(hparams, critic_loss, global_step):
  """Create Critic train op."""
  with tf.name_scope('train_critic'):
    critic_optimizer = tf.train.AdamOptimizer(hparams.critic_learning_rate)
    output_vars = [v for v in tf.trainable_variables() if v.op.name.startswith('critic')]

    if FLAGS.critic_update_dis_vars:
      if FLAGS.discriminator_model == 'bidirectional_vd':
        critic_vars = [v for v in tf.trainable_variables() if v.op.name.startswith('dis/rnn')]
      elif FLAGS.discriminator_model == 'seq2seq_vd':
        critic_vars = [
            v for v in tf.trainable_variables()
            if v.op.name.startswith('dis/decoder/rnn/multi_rnn_cell')
        ]
      critic_vars.extend(output_vars)
    else:
      critic_vars = output_vars
    print('\nOptimizing Critic vars:')
    for v in critic_vars:
      print(v)
    critic_grads = tf.gradients(critic_loss, critic_vars)
    critic_grads_clipped, _ = tf.clip_by_global_norm(critic_grads, FLAGS.grad_clipping)
    critic_train_op = critic_optimizer.apply_gradients(
        zip(critic_grads_clipped, critic_vars), global_step=global_step)
    return critic_train_op, critic_grads_clipped, critic_vars
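All of these train ops follow the same pattern: filter tf.trainable_variables() by name prefix, differentiate the loss, clip the gradients by global norm, and apply them. A minimal self-contained sketch of that pattern in TensorFlow 1.x graph mode (the variable name, loss, learning rate, and clip norm below are illustrative stand-ins, not the repository's flags or hyperparameters):

import tensorflow as tf

w = tf.get_variable('gen/w', shape=[10], initializer=tf.zeros_initializer())
loss = tf.reduce_sum(tf.square(w - 1.0))
global_step = tf.train.get_or_create_global_step()

# Filter by scope prefix, then clip-and-apply, mirroring the ops above.
gen_vars = [v for v in tf.trainable_variables() if v.op.name.startswith('gen')]
grads = tf.gradients(loss, gen_vars)
clipped, _ = tf.clip_by_global_norm(grads, 5.0)      # stand-in for FLAGS.grad_clipping
train_op = tf.train.AdamOptimizer(1e-3).apply_gradients(
    list(zip(clipped, gen_vars)), global_step=global_step)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for _ in range(3):
    _, loss_value = sess.run([train_op, loss])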
research/maskgan/model_utils/model_utils.py
0 → 100644
View file @
7d16fc45
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

import numpy as np
import tensorflow as tf

from model_utils import variable_mapping

FLAGS = tf.app.flags.FLAGS


def generate_mask():
  """Generate the mask to be fed into the model."""
  if FLAGS.mask_strategy == 'random':
    p = np.random.choice(
        [True, False],
        size=[FLAGS.batch_size, FLAGS.sequence_length],
        p=[FLAGS.is_present_rate, 1. - FLAGS.is_present_rate])

  elif FLAGS.mask_strategy == 'contiguous':
    masked_length = int((1 - FLAGS.is_present_rate) * FLAGS.sequence_length) - 1
    # Determine location to start masking.
    start_mask = np.random.randint(
        1, FLAGS.sequence_length - masked_length + 1, size=FLAGS.batch_size)
    p = np.full([FLAGS.batch_size, FLAGS.sequence_length], True, dtype=bool)

    # Create contiguous masked section to be False.
    for i, index in enumerate(start_mask):
      p[i, index:index + masked_length] = False

  else:
    raise NotImplementedError

  return p
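As a quick illustration of the 'contiguous' strategy, here is a NumPy sketch with made-up values standing in for FLAGS.batch_size, FLAGS.sequence_length, and FLAGS.is_present_rate; it is not part of the module above.

import numpy as np

batch_size, sequence_length, is_present_rate = 2, 10, 0.7   # hypothetical settings
masked_length = int((1 - is_present_rate) * sequence_length) - 1
start = np.random.randint(1, sequence_length - masked_length + 1, size=batch_size)
p = np.full([batch_size, sequence_length], True, dtype=bool)
for i, idx in enumerate(start):
  p[i, idx:idx + masked_length] = False      # a contiguous False (masked) run per row
print(p.astype(int))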
def assign_percent_real(session, percent_real_update, new_rate, current_rate):
  """Run the assign op that loads the current rate of real tokens into a
  TensorFlow variable.

  Args:
    session: Current tf.Session.
    percent_real_update: tf.assign operation.
    new_rate: tf.placeholder for the new rate.
    current_rate: Percent of tokens that are currently real.  Fake tokens
      are the ones being imputed by the Generator.
  """
  session.run(percent_real_update, feed_dict={new_rate: current_rate})


def assign_learning_rate(session, lr_update, lr_placeholder, new_lr):
  """Run the assign op that loads the new learning rate into a TensorFlow
  variable.

  Args:
    session: Current tf.Session.
    lr_update: tf.assign operation.
    lr_placeholder: tf.placeholder for the new learning rate.
    new_lr: New learning rate to use.
  """
  session.run(lr_update, feed_dict={lr_placeholder: new_lr})


def clip_weights(variables, c_lower, c_upper):
  """Clip a list of weights to be within a certain range.

  Args:
    variables: List of tf.Variable weights.
    c_lower: Lower bound for weights.
    c_upper: Upper bound for weights.
  """
  clip_ops = []

  for var in variables:
    clipped_var = tf.clip_by_value(var, c_lower, c_upper)
    clip_ops.append(tf.assign(var, clipped_var))
  return tf.group(*clip_ops)
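A small self-contained check of the clipping pattern used above, with a hypothetical variable and bounds (TensorFlow 1.x graph mode; illustrative only):

import tensorflow as tf

v = tf.Variable([-2.0, 0.5, 3.0], name='w')
clip_op = tf.group(tf.assign(v, tf.clip_by_value(v, -1.0, 1.0)))

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(clip_op)
  print(sess.run(v))   # [-1.   0.5  1. ]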
def retrieve_init_savers(hparams):
  """Retrieve a dictionary of all the initial savers for the models.

  Args:
    hparams: MaskGAN hyperparameters.
  """
  ## Dictionary of init savers.
  init_savers = {}

  ## Load Generator weights from MaskGAN checkpoint.
  if FLAGS.maskgan_ckpt:
    gen_vars = [v for v in tf.trainable_variables() if v.op.name.startswith('gen')]
    init_saver = tf.train.Saver(var_list=gen_vars)
    init_savers['init_saver'] = init_saver

    ## Load the Discriminator weights from the MaskGAN checkpoint if
    # the weights are compatible.
    if FLAGS.discriminator_model == 'seq2seq_vd':
      dis_variable_maps = variable_mapping.dis_seq2seq_vd(hparams)
      dis_init_saver = tf.train.Saver(var_list=dis_variable_maps)
      init_savers['dis_init_saver'] = dis_init_saver

  ## Load weights from language model checkpoint.
  if FLAGS.language_model_ckpt_dir:
    if FLAGS.maskgan_ckpt is None:
      ## Generator Variables/Savers.
      if FLAGS.generator_model == 'rnn_nas':
        gen_variable_maps = variable_mapping.rnn_nas(hparams, model='gen')
        gen_init_saver = tf.train.Saver(var_list=gen_variable_maps)
        init_savers['gen_init_saver'] = gen_init_saver

      elif FLAGS.generator_model == 'seq2seq_nas':
        # Encoder.
        gen_encoder_variable_maps = variable_mapping.gen_encoder_seq2seq_nas(hparams)
        gen_encoder_init_saver = tf.train.Saver(var_list=gen_encoder_variable_maps)
        # Decoder.
        gen_decoder_variable_maps = variable_mapping.gen_decoder_seq2seq_nas(hparams)
        gen_decoder_init_saver = tf.train.Saver(var_list=gen_decoder_variable_maps)
        init_savers['gen_encoder_init_saver'] = gen_encoder_init_saver
        init_savers['gen_decoder_init_saver'] = gen_decoder_init_saver

      # seq2seq_vd derived from the same code base as seq2seq_zaremba.
      elif (FLAGS.generator_model == 'seq2seq_zaremba' or
            FLAGS.generator_model == 'seq2seq_vd'):
        # Encoder.
        gen_encoder_variable_maps = variable_mapping.gen_encoder_seq2seq(hparams)
        gen_encoder_init_saver = tf.train.Saver(var_list=gen_encoder_variable_maps)
        # Decoder.
        gen_decoder_variable_maps = variable_mapping.gen_decoder_seq2seq(hparams)
        gen_decoder_init_saver = tf.train.Saver(var_list=gen_decoder_variable_maps)
        init_savers['gen_encoder_init_saver'] = gen_encoder_init_saver
        init_savers['gen_decoder_init_saver'] = gen_decoder_init_saver

      else:
        raise NotImplementedError

      ## Discriminator Variables/Savers.
      if FLAGS.discriminator_model == 'rnn_nas':
        dis_variable_maps = variable_mapping.rnn_nas(hparams, model='dis')
        dis_init_saver = tf.train.Saver(var_list=dis_variable_maps)
        init_savers['dis_init_saver'] = dis_init_saver

      # rnn_vd derived from the same code base as rnn_zaremba.
      elif (FLAGS.discriminator_model == 'rnn_zaremba' or
            FLAGS.discriminator_model == 'rnn_vd'):
        dis_variable_maps = variable_mapping.rnn_zaremba(hparams, model='dis')
        dis_init_saver = tf.train.Saver(var_list=dis_variable_maps)
        init_savers['dis_init_saver'] = dis_init_saver

      elif (FLAGS.discriminator_model == 'bidirectional_zaremba' or
            FLAGS.discriminator_model == 'bidirectional_vd'):
        dis_fwd_variable_maps = variable_mapping.dis_fwd_bidirectional(hparams)
        dis_bwd_variable_maps = variable_mapping.dis_bwd_bidirectional(hparams)
        # Savers for the forward/backward Discriminator components.
        dis_fwd_init_saver = tf.train.Saver(var_list=dis_fwd_variable_maps)
        dis_bwd_init_saver = tf.train.Saver(var_list=dis_bwd_variable_maps)
        init_savers['dis_fwd_init_saver'] = dis_fwd_init_saver
        init_savers['dis_bwd_init_saver'] = dis_bwd_init_saver

      elif FLAGS.discriminator_model == 'cnn':
        dis_variable_maps = variable_mapping.cnn()
        dis_init_saver = tf.train.Saver(var_list=dis_variable_maps)
        init_savers['dis_init_saver'] = dis_init_saver

      elif FLAGS.discriminator_model == 'seq2seq_vd':
        # Encoder.
        dis_encoder_variable_maps = variable_mapping.dis_encoder_seq2seq(hparams)
        dis_encoder_init_saver = tf.train.Saver(var_list=dis_encoder_variable_maps)
        # Decoder.
        dis_decoder_variable_maps = variable_mapping.dis_decoder_seq2seq(hparams)
        dis_decoder_init_saver = tf.train.Saver(var_list=dis_decoder_variable_maps)
        init_savers['dis_encoder_init_saver'] = dis_encoder_init_saver
        init_savers['dis_decoder_init_saver'] = dis_decoder_init_saver

  return init_savers


def init_fn(init_savers, sess):
  """The init_fn to be passed to the Supervisor.

  Args:
    init_savers: Dictionary of init_savers.  'init_saver_name': init_saver.
    sess: tf.Session.
  """
  ## Load Generator weights from MaskGAN checkpoint.
  if FLAGS.maskgan_ckpt:
    print('Restoring Generator from %s.' % FLAGS.maskgan_ckpt)
    tf.logging.info('Restoring Generator from %s.' % FLAGS.maskgan_ckpt)
    print('Asserting Generator is a seq2seq-variant.')
    tf.logging.info('Asserting Generator is a seq2seq-variant.')
    assert FLAGS.generator_model.startswith('seq2seq')
    init_saver = init_savers['init_saver']
    init_saver.restore(sess, FLAGS.maskgan_ckpt)

    ## Load the Discriminator weights from the MaskGAN checkpoint if
    # the weights are compatible.
    if FLAGS.discriminator_model == 'seq2seq_vd':
      print('Restoring Discriminator from %s.' % FLAGS.maskgan_ckpt)
      tf.logging.info('Restoring Discriminator from %s.' % FLAGS.maskgan_ckpt)
      dis_init_saver = init_savers['dis_init_saver']
      dis_init_saver.restore(sess, FLAGS.maskgan_ckpt)

  ## Load weights from language model checkpoint.
  if FLAGS.language_model_ckpt_dir:
    if FLAGS.maskgan_ckpt is None:
      ## Generator Models.
      if FLAGS.generator_model == 'rnn_nas':
        load_ckpt = tf.train.latest_checkpoint(FLAGS.language_model_ckpt_dir)
        print('Restoring Generator from %s.' % load_ckpt)
        tf.logging.info('Restoring Generator from %s.' % load_ckpt)
        gen_init_saver = init_savers['gen_init_saver']
        gen_init_saver.restore(sess, load_ckpt)

      elif FLAGS.generator_model.startswith('seq2seq'):
        load_ckpt = tf.train.latest_checkpoint(FLAGS.language_model_ckpt_dir)
        print('Restoring Generator from %s.' % load_ckpt)
        tf.logging.info('Restoring Generator from %s.' % load_ckpt)
        gen_encoder_init_saver = init_savers['gen_encoder_init_saver']
        gen_decoder_init_saver = init_savers['gen_decoder_init_saver']
        gen_encoder_init_saver.restore(sess, load_ckpt)
        gen_decoder_init_saver.restore(sess, load_ckpt)

      ## Discriminator Models.
      if (FLAGS.discriminator_model == 'rnn_nas' or
          FLAGS.discriminator_model == 'rnn_zaremba' or
          FLAGS.discriminator_model == 'rnn_vd' or
          FLAGS.discriminator_model == 'cnn'):
        load_ckpt = tf.train.latest_checkpoint(FLAGS.language_model_ckpt_dir)
        print('Restoring Discriminator from %s.' % load_ckpt)
        tf.logging.info('Restoring Discriminator from %s.' % load_ckpt)
        dis_init_saver = init_savers['dis_init_saver']
        dis_init_saver.restore(sess, load_ckpt)

      elif (FLAGS.discriminator_model == 'bidirectional_zaremba' or
            FLAGS.discriminator_model == 'bidirectional_vd'):
        assert FLAGS.language_model_ckpt_dir_reversed is not None, (
            'Need a reversed directory to fill in the backward components.')
        load_fwd_ckpt = tf.train.latest_checkpoint(FLAGS.language_model_ckpt_dir)
        load_bwd_ckpt = tf.train.latest_checkpoint(FLAGS.language_model_ckpt_dir_reversed)
        print('Restoring Discriminator from %s and %s.' % (load_fwd_ckpt, load_bwd_ckpt))
        tf.logging.info('Restoring Discriminator from %s and %s.' %
                        (load_fwd_ckpt, load_bwd_ckpt))
        dis_fwd_init_saver = init_savers['dis_fwd_init_saver']
        dis_bwd_init_saver = init_savers['dis_bwd_init_saver']
        dis_fwd_init_saver.restore(sess, load_fwd_ckpt)
        dis_bwd_init_saver.restore(sess, load_bwd_ckpt)

      elif FLAGS.discriminator_model == 'seq2seq_vd':
        load_ckpt = tf.train.latest_checkpoint(FLAGS.language_model_ckpt_dir)
        print('Restoring Discriminator from %s.' % load_ckpt)
        tf.logging.info('Restoring Discriminator from %s.' % load_ckpt)
        dis_encoder_init_saver = init_savers['dis_encoder_init_saver']
        dis_decoder_init_saver = init_savers['dis_decoder_init_saver']
        dis_encoder_init_saver.restore(sess, load_ckpt)
        dis_decoder_init_saver.restore(sess, load_ckpt)

      else:
        return
research/maskgan/model_utils/n_gram.py
0 → 100644
View file @
7d16fc45
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""We calculate n-Grams from the training text. We will use this as an
evaluation metric."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


def hash_function(input_tuple):
  """Hash function for a tuple."""
  return hash(input_tuple)


def find_all_ngrams(dataset, n):
  """Generate a list of all ngrams."""
  return zip(*[dataset[i:] for i in xrange(n)])


def construct_ngrams_dict(ngrams_list):
  """Construct a ngram dictionary which maps an ngram tuple to the number
  of times it appears in the text."""
  counts = {}

  for t in ngrams_list:
    key = hash_function(t)
    if key in counts:
      counts[key] += 1
    else:
      counts[key] = 1
  return counts


def percent_unique_ngrams_in_train(train_ngrams_dict, gen_ngrams_dict):
  """Compute the percent of ngrams generated by the model that are
  present in the training text and are unique."""

  # *Total* number of n-grams produced by the generator.
  total_ngrams_produced = 0

  for _, value in gen_ngrams_dict.iteritems():
    total_ngrams_produced += value

  # The unique ngrams in the training set.
  unique_ngrams_in_train = 0.

  for key, _ in gen_ngrams_dict.iteritems():
    if key in train_ngrams_dict:
      unique_ngrams_in_train += 1
  return float(unique_ngrams_in_train) / float(total_ngrams_produced)
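To make the metric concrete, here is essentially the same ratio computed on toy sentences. The helpers above use Python 2's xrange/iteritems; the collections.Counter version below is only an illustration, not repository code.

from collections import Counter

def ngrams(tokens, n):
  return list(zip(*[tokens[i:] for i in range(n)]))

train_counts = Counter(ngrams('the cat sat on the mat'.split(), 3))
gen_counts = Counter(ngrams('the cat sat on a mat'.split(), 3))

total_produced = sum(gen_counts.values())
seen_in_train = sum(1.0 for gram in gen_counts if gram in train_counts)
print(seen_in_train / total_produced)   # 0.5: two of four generated 3-grams appear in training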
research/maskgan/model_utils/variable_mapping.py
0 → 100644
View file @
7d16fc45
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS


def rnn_nas(hparams, model):
  assert model == 'gen' or model == 'dis'

  # This logic is only valid for rnn_zaremba
  if model == 'gen':
    assert FLAGS.generator_model == 'rnn_nas'
    assert hparams.gen_num_layers == 2

  if model == 'dis':
    assert FLAGS.discriminator_model == 'rnn_nas'
    assert hparams.dis_num_layers == 2

  # Output variables only for the Generator.  Discriminator output biases
  # will begin randomly initialized.
  if model == 'gen':
    softmax_b = [v for v in tf.trainable_variables() if v.op.name == 'gen/rnn/softmax_b'][0]

  # Common elements to Generator and Discriminator.
  embedding = [v for v in tf.trainable_variables() if v.op.name == str(model) + '/rnn/embedding'][0]
  lstm_w_0 = [v for v in tf.trainable_variables() if v.op.name == str(model) + '/rnn/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_h_mat'][0]
  lstm_b_0 = [v for v in tf.trainable_variables() if v.op.name == str(model) + '/rnn/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_inputs_mat'][0]
  lstm_w_1 = [v for v in tf.trainable_variables() if v.op.name == str(model) + '/rnn/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_h_mat'][0]
  lstm_b_1 = [v for v in tf.trainable_variables() if v.op.name == str(model) + '/rnn/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_inputs_mat'][0]

  # Dictionary mapping.
  if model == 'gen':
    variable_mapping = {
        'Model/embeddings/input_embedding': embedding,
        'Model/RNN/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_h_mat': lstm_w_0,
        'Model/RNN/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_inputs_mat': lstm_b_0,
        'Model/RNN/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_h_mat': lstm_w_1,
        'Model/RNN/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_inputs_mat': lstm_b_1,
        'Model/softmax_b': softmax_b
    }
  else:
    variable_mapping = {
        'Model/embeddings/input_embedding': embedding,
        'Model/RNN/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_h_mat': lstm_w_0,
        'Model/RNN/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_inputs_mat': lstm_b_0,
        'Model/RNN/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_h_mat': lstm_w_1,
        'Model/RNN/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_inputs_mat': lstm_b_1
    }

  return variable_mapping


def cnn():
  """Variable mapping for the CNN embedding.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  # This logic is only valid for cnn
  assert FLAGS.discriminator_model == 'cnn'

  # Retrieve CNN embedding.
  embedding = [v for v in tf.trainable_variables() if v.op.name == 'dis/embedding'][0]

  # Variable mapping.
  variable_mapping = {'Model/embedding': embedding}

  return variable_mapping


def rnn_zaremba(hparams, model):
  """Returns the PTB Variable name to MaskGAN Variable dictionary mapping.  This
  is a highly restrictive function just for testing.  This will need to be
  generalized.

  Args:
    hparams: Hyperparameters for the MaskGAN.
    model: Model type, one of ['gen', 'dis'].

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert model == 'gen' or model == 'dis'

  # This logic is only valid for rnn_zaremba
  if model == 'gen':
    assert FLAGS.generator_model == 'rnn_zaremba'
    assert hparams.gen_num_layers == 2

  if model == 'dis':
    assert (FLAGS.discriminator_model == 'rnn_zaremba' or
            FLAGS.discriminator_model == 'rnn_vd')
    assert hparams.dis_num_layers == 2

  # Output variables only for the Generator.  Discriminator output weights
  # and biases will begin randomly initialized.
  if model == 'gen':
    softmax_w = [v for v in tf.trainable_variables() if v.op.name == 'gen/rnn/softmax_w'][0]
    softmax_b = [v for v in tf.trainable_variables() if v.op.name == 'gen/rnn/softmax_b'][0]

  # Common elements to Generator and Discriminator.
  if not FLAGS.dis_share_embedding or model != 'dis':
    embedding = [v for v in tf.trainable_variables() if v.op.name == str(model) + '/rnn/embedding'][0]
  lstm_w_0 = [v for v in tf.trainable_variables() if v.op.name == str(model) + '/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/weights'][0]
  lstm_b_0 = [v for v in tf.trainable_variables() if v.op.name == str(model) + '/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/biases'][0]
  lstm_w_1 = [v for v in tf.trainable_variables() if v.op.name == str(model) + '/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/weights'][0]
  lstm_b_1 = [v for v in tf.trainable_variables() if v.op.name == str(model) + '/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases'][0]

  # Dictionary mapping.
  if model == 'gen':
    variable_mapping = {
        'Model/embedding': embedding,
        'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': lstm_w_0,
        'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': lstm_b_0,
        'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': lstm_w_1,
        'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': lstm_b_1,
        'Model/softmax_w': softmax_w,
        'Model/softmax_b': softmax_b
    }
  else:
    if FLAGS.dis_share_embedding:
      variable_mapping = {
          'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': lstm_w_0,
          'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': lstm_b_0,
          'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': lstm_w_1,
          'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': lstm_b_1
      }
    else:
      variable_mapping = {
          'Model/embedding': embedding,
          'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': lstm_w_0,
          'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': lstm_b_0,
          'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': lstm_w_1,
          'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': lstm_b_1
      }

  return variable_mapping


def gen_encoder_seq2seq_nas(hparams):
  """Returns the NAS Variable name to MaskGAN Variable
  dictionary mapping.  This is a highly restrictive function just for testing.
  This is for the *unidirectional* seq2seq_nas encoder.

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert FLAGS.generator_model == 'seq2seq_nas'
  assert hparams.gen_num_layers == 2

  ## Encoder forward variables.
  if not FLAGS.seq2seq_share_embedding:
    encoder_embedding = [v for v in tf.trainable_variables() if v.op.name == 'gen/encoder/rnn/embedding'][0]
  encoder_lstm_w_0 = [v for v in tf.trainable_variables() if v.op.name == 'gen/encoder/rnn/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_h_mat'][0]
  encoder_lstm_b_0 = [v for v in tf.trainable_variables() if v.op.name == 'gen/encoder/rnn/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_inputs_mat'][0]
  encoder_lstm_w_1 = [v for v in tf.trainable_variables() if v.op.name == 'gen/encoder/rnn/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_h_mat'][0]
  encoder_lstm_b_1 = [v for v in tf.trainable_variables() if v.op.name == 'gen/encoder/rnn/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_inputs_mat'][0]

  if not FLAGS.seq2seq_share_embedding:
    variable_mapping = {
        'Model/embeddings/input_embedding': encoder_embedding,
        'Model/RNN/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_h_mat': encoder_lstm_w_0,
        'Model/RNN/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_inputs_mat': encoder_lstm_b_0,
        'Model/RNN/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_h_mat': encoder_lstm_w_1,
        'Model/RNN/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_inputs_mat': encoder_lstm_b_1
    }
  else:
    variable_mapping = {
        'Model/RNN/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_h_mat': encoder_lstm_w_0,
        'Model/RNN/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_inputs_mat': encoder_lstm_b_0,
        'Model/RNN/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_h_mat': encoder_lstm_w_1,
        'Model/RNN/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_inputs_mat': encoder_lstm_b_1
    }
  return variable_mapping


def gen_decoder_seq2seq_nas(hparams):
  assert FLAGS.generator_model == 'seq2seq_nas'
  assert hparams.gen_num_layers == 2

  decoder_embedding = [v for v in tf.trainable_variables() if v.op.name == 'gen/decoder/rnn/embedding'][0]
  decoder_lstm_w_0 = [v for v in tf.trainable_variables() if v.op.name == 'gen/decoder/rnn/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_h_mat'][0]
  decoder_lstm_b_0 = [v for v in tf.trainable_variables() if v.op.name == 'gen/decoder/rnn/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_inputs_mat'][0]
  decoder_lstm_w_1 = [v for v in tf.trainable_variables() if v.op.name == 'gen/decoder/rnn/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_h_mat'][0]
  decoder_lstm_b_1 = [v for v in tf.trainable_variables() if v.op.name == 'gen/decoder/rnn/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_inputs_mat'][0]
  decoder_softmax_b = [v for v in tf.trainable_variables() if v.op.name == 'gen/decoder/rnn/softmax_b'][0]

  variable_mapping = {
      'Model/embeddings/input_embedding': decoder_embedding,
      'Model/RNN/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_h_mat': decoder_lstm_w_0,
      'Model/RNN/GenericMultiRNNCell/Cell0/Alien/rnn_builder/big_inputs_mat': decoder_lstm_b_0,
      'Model/RNN/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_h_mat': decoder_lstm_w_1,
      'Model/RNN/GenericMultiRNNCell/Cell1/Alien/rnn_builder/big_inputs_mat': decoder_lstm_b_1,
      'Model/softmax_b': decoder_softmax_b
  }
  return variable_mapping


def gen_encoder_seq2seq(hparams):
  """Returns the PTB Variable name to MaskGAN Variable
  dictionary mapping.  This is a highly restrictive function just for testing.
  This is for the *unidirectional* seq2seq_zaremba encoder.

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert (FLAGS.generator_model == 'seq2seq_zaremba' or
          FLAGS.generator_model == 'seq2seq_vd')
  assert hparams.gen_num_layers == 2

  ## Encoder forward variables.
  if not FLAGS.seq2seq_share_embedding:
    encoder_embedding = [v for v in tf.trainable_variables() if v.op.name == 'gen/encoder/rnn/embedding'][0]
  encoder_lstm_w_0 = [v for v in tf.trainable_variables() if v.op.name == 'gen/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/weights'][0]
  encoder_lstm_b_0 = [v for v in tf.trainable_variables() if v.op.name == 'gen/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/biases'][0]
  encoder_lstm_w_1 = [v for v in tf.trainable_variables() if v.op.name == 'gen/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/weights'][0]
  encoder_lstm_b_1 = [v for v in tf.trainable_variables() if v.op.name == 'gen/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases'][0]

  if FLAGS.data_set == 'ptb':
    model_str = 'Model'
  else:
    model_str = 'model'

  if not FLAGS.seq2seq_share_embedding:
    variable_mapping = {
        str(model_str) + '/embedding': encoder_embedding,
        str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': encoder_lstm_w_0,
        str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': encoder_lstm_b_0,
        str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': encoder_lstm_w_1,
        str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': encoder_lstm_b_1
    }
  else:
    variable_mapping = {
        str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': encoder_lstm_w_0,
        str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': encoder_lstm_b_0,
        str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': encoder_lstm_w_1,
        str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': encoder_lstm_b_1
    }
  return variable_mapping


def gen_decoder_seq2seq(hparams):
  assert (FLAGS.generator_model == 'seq2seq_zaremba' or
          FLAGS.generator_model == 'seq2seq_vd')
  assert hparams.gen_num_layers == 2

  decoder_embedding = [v for v in tf.trainable_variables() if v.op.name == 'gen/decoder/rnn/embedding'][0]
  decoder_lstm_w_0 = [v for v in tf.trainable_variables() if v.op.name == 'gen/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/weights'][0]
  decoder_lstm_b_0 = [v for v in tf.trainable_variables() if v.op.name == 'gen/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/biases'][0]
  decoder_lstm_w_1 = [v for v in tf.trainable_variables() if v.op.name == 'gen/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/weights'][0]
  decoder_lstm_b_1 = [v for v in tf.trainable_variables() if v.op.name == 'gen/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases'][0]
  decoder_softmax_b = [v for v in tf.trainable_variables() if v.op.name == 'gen/decoder/rnn/softmax_b'][0]

  if FLAGS.data_set == 'ptb':
    model_str = 'Model'
  else:
    model_str = 'model'

  variable_mapping = {
      str(model_str) + '/embedding': decoder_embedding,
      str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': decoder_lstm_w_0,
      str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': decoder_lstm_b_0,
      str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': decoder_lstm_w_1,
      str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': decoder_lstm_b_1,
      str(model_str) + '/softmax_b': decoder_softmax_b
  }
  return variable_mapping


def dis_fwd_bidirectional(hparams):
  """Returns the *forward* PTB Variable name to MaskGAN Variable dictionary
  mapping.  This is a highly restrictive function just for testing.  This is for
  the bidirectional_zaremba discriminator.

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert (FLAGS.discriminator_model == 'bidirectional_zaremba' or
          FLAGS.discriminator_model == 'bidirectional_vd')
  assert hparams.dis_num_layers == 2

  # Forward Discriminator Elements.
  if not FLAGS.dis_share_embedding:
    embedding = [v for v in tf.trainable_variables() if v.op.name == 'dis/embedding'][0]
  fw_lstm_w_0 = [v for v in tf.trainable_variables() if v.op.name == 'dis/rnn/fw/multi_rnn_cell/cell_0/basic_lstm_cell/weights'][0]
  fw_lstm_b_0 = [v for v in tf.trainable_variables() if v.op.name == 'dis/rnn/fw/multi_rnn_cell/cell_0/basic_lstm_cell/biases'][0]
  fw_lstm_w_1 = [v for v in tf.trainable_variables() if v.op.name == 'dis/rnn/fw/multi_rnn_cell/cell_1/basic_lstm_cell/weights'][0]
  fw_lstm_b_1 = [v for v in tf.trainable_variables() if v.op.name == 'dis/rnn/fw/multi_rnn_cell/cell_1/basic_lstm_cell/biases'][0]

  if FLAGS.dis_share_embedding:
    variable_mapping = {
        'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': fw_lstm_w_0,
        'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': fw_lstm_b_0,
        'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': fw_lstm_w_1,
        'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': fw_lstm_b_1
    }
  else:
    variable_mapping = {
        'Model/embedding': embedding,
        'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': fw_lstm_w_0,
        'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': fw_lstm_b_0,
        'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': fw_lstm_w_1,
        'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': fw_lstm_b_1
    }
  return variable_mapping


def dis_bwd_bidirectional(hparams):
  """Returns the *backward* PTB Variable name to MaskGAN Variable dictionary
  mapping.  This is a highly restrictive function just for testing.  This is for
  the bidirectional_zaremba discriminator.

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert (FLAGS.discriminator_model == 'bidirectional_zaremba' or
          FLAGS.discriminator_model == 'bidirectional_vd')
  assert hparams.dis_num_layers == 2

  # Backward Discriminator Elements.
  bw_lstm_w_0 = [v for v in tf.trainable_variables() if v.op.name == 'dis/rnn/bw/multi_rnn_cell/cell_0/basic_lstm_cell/weights'][0]
  bw_lstm_b_0 = [v for v in tf.trainable_variables() if v.op.name == 'dis/rnn/bw/multi_rnn_cell/cell_0/basic_lstm_cell/biases'][0]
  bw_lstm_w_1 = [v for v in tf.trainable_variables() if v.op.name == 'dis/rnn/bw/multi_rnn_cell/cell_1/basic_lstm_cell/weights'][0]
  bw_lstm_b_1 = [v for v in tf.trainable_variables() if v.op.name == 'dis/rnn/bw/multi_rnn_cell/cell_1/basic_lstm_cell/biases'][0]

  variable_mapping = {
      'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': bw_lstm_w_0,
      'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': bw_lstm_b_0,
      'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': bw_lstm_w_1,
      'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': bw_lstm_b_1
  }
  return variable_mapping


def dis_encoder_seq2seq(hparams):
  """Returns the PTB Variable name to MaskGAN Variable
  dictionary mapping.

  Args:
    hparams: Hyperparameters for the MaskGAN.

  Returns:
    variable_mapping: Dictionary with Key: ckpt_name, Value: model_var.
  """
  assert FLAGS.discriminator_model == 'seq2seq_vd'
  assert hparams.dis_num_layers == 2

  ## Encoder forward variables.
  encoder_lstm_w_0 = [v for v in tf.trainable_variables() if v.op.name == 'dis/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/weights'][0]
  encoder_lstm_b_0 = [v for v in tf.trainable_variables() if v.op.name == 'dis/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/biases'][0]
  encoder_lstm_w_1 = [v for v in tf.trainable_variables() if v.op.name == 'dis/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/weights'][0]
  encoder_lstm_b_1 = [v for v in tf.trainable_variables() if v.op.name == 'dis/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases'][0]

  if FLAGS.data_set == 'ptb':
    model_str = 'Model'
  else:
    model_str = 'model'

  variable_mapping = {
      str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': encoder_lstm_w_0,
      str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': encoder_lstm_b_0,
      str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': encoder_lstm_w_1,
      str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': encoder_lstm_b_1
  }
  return variable_mapping


def dis_decoder_seq2seq(hparams):
  assert FLAGS.discriminator_model == 'seq2seq_vd'
  assert hparams.dis_num_layers == 2

  if not FLAGS.dis_share_embedding:
    decoder_embedding = [v for v in tf.trainable_variables() if v.op.name == 'dis/decoder/rnn/embedding'][0]
  decoder_lstm_w_0 = [v for v in tf.trainable_variables() if v.op.name == 'dis/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/weights'][0]
  decoder_lstm_b_0 = [v for v in tf.trainable_variables() if v.op.name == 'dis/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/biases'][0]
  decoder_lstm_w_1 = [v for v in tf.trainable_variables() if v.op.name == 'dis/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/weights'][0]
  decoder_lstm_b_1 = [v for v in tf.trainable_variables() if v.op.name == 'dis/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases'][0]

  if FLAGS.data_set == 'ptb':
    model_str = 'Model'
  else:
    model_str = 'model'

  if not FLAGS.dis_share_embedding:
    variable_mapping = {
        str(model_str) + '/embedding': decoder_embedding,
        str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': decoder_lstm_w_0,
        str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': decoder_lstm_b_0,
        str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': decoder_lstm_w_1,
        str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': decoder_lstm_b_1
    }
  else:
    variable_mapping = {
        str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': decoder_lstm_w_0,
        str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': decoder_lstm_b_0,
        str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': decoder_lstm_w_1,
        str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': decoder_lstm_b_1,
    }
  return variable_mapping


def dis_seq2seq_vd(hparams):
  assert FLAGS.discriminator_model == 'seq2seq_vd'
  assert hparams.dis_num_layers == 2

  if not FLAGS.dis_share_embedding:
    decoder_embedding = [v for v in tf.trainable_variables() if v.op.name == 'dis/decoder/rnn/embedding'][0]

  ## Encoder variables.
  encoder_lstm_w_0 = [v for v in tf.trainable_variables() if v.op.name == 'dis/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/weights'][0]
  encoder_lstm_b_0 = [v for v in tf.trainable_variables() if v.op.name == 'dis/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/biases'][0]
  encoder_lstm_w_1 = [v for v in tf.trainable_variables() if v.op.name == 'dis/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/weights'][0]
  encoder_lstm_b_1 = [v for v in tf.trainable_variables() if v.op.name == 'dis/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases'][0]

  ## Attention.
  if FLAGS.attention_option is not None:
    decoder_attention_keys = [v for v in tf.trainable_variables() if v.op.name == 'dis/decoder/attention_keys/weights'][0]
    decoder_attention_construct_weights = [v for v in tf.trainable_variables() if v.op.name == 'dis/decoder/rnn/attention_construct/weights'][0]

  ## Decoder.
  decoder_lstm_w_0 = [v for v in tf.trainable_variables() if v.op.name == 'dis/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/weights'][0]
  decoder_lstm_b_0 = [v for v in tf.trainable_variables() if v.op.name == 'dis/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/biases'][0]
  decoder_lstm_w_1 = [v for v in tf.trainable_variables() if v.op.name == 'dis/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/weights'][0]
  decoder_lstm_b_1 = [v for v in tf.trainable_variables() if v.op.name == 'dis/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases'][0]

  # Standard variable mappings.
  variable_mapping = {
      'gen/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/weights': encoder_lstm_w_0,
      'gen/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/biases': encoder_lstm_b_0,
      'gen/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/weights': encoder_lstm_w_1,
      'gen/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases': encoder_lstm_b_1,
      'gen/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/weights': decoder_lstm_w_0,
      'gen/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/biases': decoder_lstm_b_0,
      'gen/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/weights': decoder_lstm_w_1,
      'gen/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases': decoder_lstm_b_1
  }

  # Optional variable mappings.
  if not FLAGS.dis_share_embedding:
    variable_mapping['gen/decoder/rnn/embedding'] = decoder_embedding
  if FLAGS.attention_option is not None:
    variable_mapping['gen/decoder/attention_keys/weights'] = decoder_attention_keys
    variable_mapping['gen/decoder/rnn/attention_construct/weights'] = decoder_attention_construct_weights

  return variable_mapping
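These mapping dictionaries are consumed by tf.train.Saver(var_list=...), which accepts a {checkpoint_name: variable} dict and restores each in-graph variable from a differently named checkpoint tensor. A minimal TensorFlow 1.x sketch of that mechanism (the variable shape, names, and checkpoint path are purely illustrative, not the repository's):

import tensorflow as tf

embedding = tf.get_variable('embedding', shape=[8, 4])     # stands in for a MaskGAN variable
init_saver = tf.train.Saver(var_list={'Model/embedding': embedding})  # ckpt name -> variable

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # init_saver.restore(sess, '/tmp/lm_ckpt/model.ckpt')     # hypothetical checkpoint path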
research/maskgan/models/__init__.py
0 → 100644
View file @
7d16fc45
research/maskgan/models/attention_utils.py
0 → 100644
View file @
7d16fc45
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Attention-based decoder functions."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
from
tensorflow.python.framework
import
function
__all__
=
[
"prepare_attention"
,
"attention_decoder_fn_train"
,
"attention_decoder_fn_inference"
]
def
attention_decoder_fn_train
(
encoder_state
,
attention_keys
,
attention_values
,
attention_score_fn
,
attention_construct_fn
,
name
=
None
):
"""Attentional decoder function for `dynamic_rnn_decoder` during training.
The `attention_decoder_fn_train` is a training function for an
attention-based sequence-to-sequence model. It should be used when
`dynamic_rnn_decoder` is in the training mode.
The `attention_decoder_fn_train` is called with a set of the user arguments
and returns the `decoder_fn`, which can be passed to the
`dynamic_rnn_decoder`, such that
```
dynamic_fn_train = attention_decoder_fn_train(encoder_state)
outputs_train, state_train = dynamic_rnn_decoder(
decoder_fn=dynamic_fn_train, ...)
```
Further usage can be found in the `kernel_tests/seq2seq_test.py`.
Args:
encoder_state: The encoded state to initialize the `dynamic_rnn_decoder`.
attention_keys: to be compared with target states.
attention_values: to be used to construct context vectors.
attention_score_fn: to compute similarity between key and target states.
attention_construct_fn: to build attention states.
name: (default: `None`) NameScope for the decoder function;
defaults to "simple_decoder_fn_train"
Returns:
A decoder function with the required interface of `dynamic_rnn_decoder`
intended for training.
"""
with
tf
.
name_scope
(
name
,
"attention_decoder_fn_train"
,
[
encoder_state
,
attention_keys
,
attention_values
,
attention_score_fn
,
attention_construct_fn
]):
pass
def
decoder_fn
(
time
,
cell_state
,
cell_input
,
cell_output
,
context_state
):
"""Decoder function used in the `dynamic_rnn_decoder` for training.
Args:
time: positive integer constant reflecting the current timestep.
cell_state: state of RNNCell.
cell_input: input provided by `dynamic_rnn_decoder`.
cell_output: output of RNNCell.
context_state: context state provided by `dynamic_rnn_decoder`.
Returns:
A tuple (done, next state, next input, emit output, next context state)
where:
done: `None`, which is used by the `dynamic_rnn_decoder` to indicate
that `sequence_lengths` in `dynamic_rnn_decoder` should be used.
next state: `cell_state`, this decoder function does not modify the
given state.
next input: `cell_input`, this decoder function does not modify the
given input. The input could be modified when applying e.g. attention.
emit output: `cell_output`, this decoder function does not modify the
given output.
next context state: `context_state`, this decoder function does not
modify the given context state. The context state could be modified when
applying e.g. beam search.
"""
with
tf
.
name_scope
(
name
,
"attention_decoder_fn_train"
,
[
time
,
cell_state
,
cell_input
,
cell_output
,
context_state
]):
if
cell_state
is
None
:
# first call, return encoder_state
cell_state
=
encoder_state
# init attention
attention
=
_init_attention
(
encoder_state
)
else
:
# construct attention
attention
=
attention_construct_fn
(
cell_output
,
attention_keys
,
attention_values
)
cell_output
=
attention
# combine cell_input and attention
next_input
=
tf
.
concat
([
cell_input
,
attention
],
1
)
return
(
None
,
cell_state
,
next_input
,
cell_output
,
context_state
)
return
decoder_fn
def
attention_decoder_fn_inference
(
output_fn
,
encoder_state
,
attention_keys
,
attention_values
,
attention_score_fn
,
attention_construct_fn
,
embeddings
,
start_of_sequence_id
,
end_of_sequence_id
,
maximum_length
,
num_decoder_symbols
,
dtype
=
tf
.
int32
,
name
=
None
):
"""Attentional decoder function for `dynamic_rnn_decoder` during inference.
The `attention_decoder_fn_inference` is a simple inference function for a
sequence-to-sequence model. It should be used when `dynamic_rnn_decoder` is
in the inference mode.
The `attention_decoder_fn_inference` is called with user arguments
and returns the `decoder_fn`, which can be passed to the
`dynamic_rnn_decoder`, such that
```
dynamic_fn_inference = attention_decoder_fn_inference(...)
outputs_inference, state_inference = dynamic_rnn_decoder(
decoder_fn=dynamic_fn_inference, ...)
```
Further usage can be found in the `kernel_tests/seq2seq_test.py`.
Args:
output_fn: An output function to project your `cell_output` onto class
logits.
An example of an output function;
```
tf.variable_scope("decoder") as varscope
output_fn = lambda x: tf.contrib.layers.linear(x, num_decoder_symbols,
scope=varscope)
outputs_train, state_train = seq2seq.dynamic_rnn_decoder(...)
logits_train = output_fn(outputs_train)
varscope.reuse_variables()
logits_inference, state_inference = seq2seq.dynamic_rnn_decoder(
output_fn=output_fn, ...)
```
If `None` is supplied it will act as an identity function, which
might be wanted when using the RNNCell `OutputProjectionWrapper`.
encoder_state: The encoded state to initialize the `dynamic_rnn_decoder`.
attention_keys: to be compared with target states.
attention_values: to be used to construct context vectors.
attention_score_fn: to compute similarity between key and target states.
attention_construct_fn: to build attention states.
embeddings: The embeddings matrix used for the decoder sized
`[num_decoder_symbols, embedding_size]`.
start_of_sequence_id: The start of sequence ID in the decoder embeddings.
end_of_sequence_id: The end of sequence ID in the decoder embeddings.
maximum_length: The maximum allowed of time steps to decode.
num_decoder_symbols: The number of classes to decode at each time step.
dtype: (default: `tf.int32`) The default data type to use when
handling integer objects.
name: (default: `None`) NameScope for the decoder function;
defaults to "attention_decoder_fn_inference"
Returns:
A decoder function with the required interface of `dynamic_rnn_decoder`
intended for inference.
"""
with
tf
.
name_scope
(
name
,
"attention_decoder_fn_inference"
,
[
output_fn
,
encoder_state
,
attention_keys
,
attention_values
,
attention_score_fn
,
attention_construct_fn
,
embeddings
,
start_of_sequence_id
,
end_of_sequence_id
,
maximum_length
,
num_decoder_symbols
,
dtype
]):
start_of_sequence_id
=
tf
.
convert_to_tensor
(
start_of_sequence_id
,
dtype
)
end_of_sequence_id
=
tf
.
convert_to_tensor
(
end_of_sequence_id
,
dtype
)
maximum_length
=
tf
.
convert_to_tensor
(
maximum_length
,
dtype
)
num_decoder_symbols
=
tf
.
convert_to_tensor
(
num_decoder_symbols
,
dtype
)
encoder_info
=
tf
.
contrib
.
framework
.
nest
.
flatten
(
encoder_state
)[
0
]
batch_size
=
encoder_info
.
get_shape
()[
0
].
value
if
output_fn
is
None
:
output_fn
=
lambda
x
:
x
if
batch_size
is
None
:
batch_size
=
tf
.
shape
(
encoder_info
)[
0
]
def
decoder_fn
(
time
,
cell_state
,
cell_input
,
cell_output
,
context_state
):
"""Decoder function used in the `dynamic_rnn_decoder` for inference.
The main difference between this decoder function and the `decoder_fn` in
`attention_decoder_fn_train` is how `next_cell_input` is calculated. In
decoder function we calculate the next input by applying an argmax across
the feature dimension of the output from the decoder. This is a
greedy-search approach. (Bahdanau et al., 2014) & (Sutskever et al., 2014)
use beam-search instead.
Args:
time: positive integer constant reflecting the current timestep.
cell_state: state of RNNCell.
cell_input: input provided by `dynamic_rnn_decoder`.
cell_output: output of RNNCell.
context_state: context state provided by `dynamic_rnn_decoder`.
Returns:
A tuple (done, next state, next input, emit output, next context state)
where:
done: A boolean vector to indicate which sentences has reached a
`end_of_sequence_id`. This is used for early stopping by the
`dynamic_rnn_decoder`. When `time>=maximum_length` a boolean vector with
all elements as `true` is returned.
next state: `cell_state`, this decoder function does not modify the
given state.
next input: The embedding from argmax of the `cell_output` is used as
`next_input`.
emit output: If `output_fn is None` the supplied `cell_output` is
returned, else the `output_fn` is used to update the `cell_output`
before calculating `next_input` and returning `cell_output`.
next context state: `context_state`, this decoder function does not
modify the given context state. The context state could be modified when
applying e.g. beam search.
Raises:
ValueError: if cell_input is not None.
"""
with
tf
.
name_scope
(
name
,
"attention_decoder_fn_inference"
,
[
time
,
cell_state
,
cell_input
,
cell_output
,
context_state
]):
if
cell_input
is
not
None
:
raise
ValueError
(
"Expected cell_input to be None, but saw: %s"
%
cell_input
)
if
cell_output
is
None
:
# invariant that this is time == 0
next_input_id
=
tf
.
ones
(
[
batch_size
,
],
dtype
=
dtype
)
*
(
start_of_sequence_id
)
done
=
tf
.
zeros
(
[
batch_size
,
],
dtype
=
tf
.
bool
)
cell_state
=
encoder_state
cell_output
=
tf
.
zeros
([
num_decoder_symbols
],
dtype
=
tf
.
float32
)
cell_input
=
tf
.
gather
(
embeddings
,
next_input_id
)
# init attention
attention
=
_init_attention
(
encoder_state
)
else
:
# construct attention
attention
=
attention_construct_fn
(
cell_output
,
attention_keys
,
attention_values
)
cell_output
=
attention
# argmax decoder
cell_output
=
output_fn
(
cell_output
)
# logits
next_input_id
=
tf
.
cast
(
tf
.
argmax
(
cell_output
,
1
),
dtype
=
dtype
)
done
=
tf
.
equal
(
next_input_id
,
end_of_sequence_id
)
cell_input
=
tf
.
gather
(
embeddings
,
next_input_id
)
# combine cell_input and attention
next_input
=
tf
.
concat
([
cell_input
,
attention
],
1
)
# if time > maxlen, return all true vector
done
=
tf
.
cond
(
tf
.
greater
(
time
,
maximum_length
),
lambda
:
tf
.
ones
([
batch_size
,],
dtype
=
tf
.
bool
),
lambda
:
done
)
return
(
done
,
cell_state
,
next_input
,
cell_output
,
context_state
)
return
decoder_fn
## Helper functions ##
def
prepare_attention
(
attention_states
,
attention_option
,
num_units
,
reuse
=
None
):
"""Prepare keys/values/functions for attention.
Args:
attention_states: hidden states to attend over.
attention_option: how to compute attention, either "luong" or "bahdanau".
num_units: hidden state dimension.
reuse: whether to reuse variable scope.
Returns:
attention_keys: to be compared with target states.
attention_values: to be used to construct context vectors.
attention_score_fn: to compute similarity between key and target states.
attention_construct_fn: to build attention states.
"""
# Prepare attention keys / values from attention_states
with
tf
.
variable_scope
(
"attention_keys"
,
reuse
=
reuse
)
as
scope
:
attention_keys
=
tf
.
contrib
.
layers
.
linear
(
attention_states
,
num_units
,
biases_initializer
=
None
,
scope
=
scope
)
attention_values
=
attention_states
# Attention score function
attention_score_fn
=
_create_attention_score_fn
(
"attention_score"
,
num_units
,
attention_option
,
reuse
)
# Attention construction function
attention_construct_fn
=
_create_attention_construct_fn
(
"attention_construct"
,
num_units
,
attention_score_fn
,
reuse
)
return
(
attention_keys
,
attention_values
,
attention_score_fn
,
attention_construct_fn
)
def
_init_attention
(
encoder_state
):
"""Initialize attention. Handling both LSTM and GRU.
Args:
encoder_state: The encoded state to initialize the `dynamic_rnn_decoder`.
Returns:
attn: initial zero attention vector.
"""
# Multi- vs single-layer
# TODO(thangluong): is this the best way to check?
if
isinstance
(
encoder_state
,
tuple
):
top_state
=
encoder_state
[
-
1
]
else
:
top_state
=
encoder_state
# LSTM vs GRU
if
isinstance
(
top_state
,
tf
.
contrib
.
rnn
.
LSTMStateTuple
):
attn
=
tf
.
zeros_like
(
top_state
.
h
)
else
:
attn
=
tf
.
zeros_like
(
top_state
)
return
attn
def
_create_attention_construct_fn
(
name
,
num_units
,
attention_score_fn
,
reuse
):
"""Function to compute attention vectors.
Args:
name: to label variables.
num_units: hidden state dimension.
attention_score_fn: to compute similarity between key and target states.
reuse: whether to reuse variable scope.
Returns:
attention_construct_fn: to build attention states.
"""
def
construct_fn
(
attention_query
,
attention_keys
,
attention_values
):
with
tf
.
variable_scope
(
name
,
reuse
=
reuse
)
as
scope
:
context
=
attention_score_fn
(
attention_query
,
attention_keys
,
attention_values
)
concat_input
=
tf
.
concat
([
attention_query
,
context
],
1
)
attention
=
tf
.
contrib
.
layers
.
linear
(
concat_input
,
num_units
,
biases_initializer
=
None
,
scope
=
scope
)
return
attention
return
construct_fn
# keys: [batch_size, attention_length, attn_size]
# query: [batch_size, 1, attn_size]
# return weights [batch_size, attention_length]
@
function
.
Defun
(
func_name
=
"attn_add_fun"
,
noinline
=
True
)
def
_attn_add_fun
(
v
,
keys
,
query
):
return
tf
.
reduce_sum
(
v
*
tf
.
tanh
(
keys
+
query
),
[
2
])
@
function
.
Defun
(
func_name
=
"attn_mul_fun"
,
noinline
=
True
)
def
_attn_mul_fun
(
keys
,
query
):
return
tf
.
reduce_sum
(
keys
*
query
,
[
2
])


def _create_attention_score_fn(name,
                               num_units,
                               attention_option,
                               reuse,
                               dtype=tf.float32):
  """Different ways to compute attention scores.

  Args:
    name: to label variables.
    num_units: hidden state dimension.
    attention_option: how to compute attention, either "luong" or "bahdanau".
      "bahdanau": additive (Bahdanau et al., ICLR'2015)
      "luong": multiplicative (Luong et al., EMNLP'2015)
    reuse: whether to reuse variable scope.
    dtype: (default: `tf.float32`) data type to use.

  Returns:
    attention_score_fn: to compute similarity between key and target states.
  """
  with tf.variable_scope(name, reuse=reuse):
    if attention_option == "bahdanau":
      query_w = tf.get_variable("attnW", [num_units, num_units], dtype=dtype)
      score_v = tf.get_variable("attnV", [num_units], dtype=dtype)

    def attention_score_fn(query, keys, values):
      """Put attention masks on attention_values using attention_keys and query.

      Args:
        query: A Tensor of shape [batch_size, num_units].
        keys: A Tensor of shape [batch_size, attention_length, num_units].
        values: A Tensor of shape [batch_size, attention_length, num_units].

      Returns:
        context_vector: A Tensor of shape [batch_size, num_units].

      Raises:
        ValueError: if attention_option is neither "luong" nor "bahdanau".
      """
      if attention_option == "bahdanau":
        # Transform the query.
        query = tf.matmul(query, query_w)

        # Reshape the query to [batch_size, 1, num_units].
        query = tf.reshape(query, [-1, 1, num_units])

        # Additive attention scores.
        scores = _attn_add_fun(score_v, keys, query)
      elif attention_option == "luong":
        # Reshape the query to [batch_size, 1, num_units].
        query = tf.reshape(query, [-1, 1, num_units])

        # Multiplicative attention scores.
        scores = _attn_mul_fun(keys, query)
      else:
        raise ValueError("Unknown attention option %s!" % attention_option)

      # Compute alignment weights.
      #   scores: [batch_size, length]
      #   alignments: [batch_size, length]
      # TODO(thangluong): do not normalize over padding positions.
      alignments = tf.nn.softmax(scores)

      # Now calculate the attention-weighted vector.
      alignments = tf.expand_dims(alignments, 2)
      context_vector = tf.reduce_sum(alignments * values, [1])
      context_vector.set_shape([None, num_units])

      return context_vector

    return attention_score_fn
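

# A minimal end-to-end sketch, not part of the model graph: build a
# Luong-style score function and a construct function, then attend over dummy
# encoder states with a dummy decoder query. The shapes (batch 32, length 20,
# 64 units) and the "example_pipeline_*" scope names are hypothetical and
# chosen only for illustration; it assumes TF 1.x graph mode.
def _example_attention_pipeline():
  batch_size, length, num_units = 32, 20, 64
  encoder_states = tf.zeros([batch_size, length, num_units])
  decoder_query = tf.zeros([batch_size, num_units])

  score_fn = _create_attention_score_fn(
      "example_pipeline_score", num_units, "luong", reuse=None)
  construct_fn = _create_attention_construct_fn(
      "example_pipeline_construct", num_units, score_fn, reuse=None)

  # With the "luong" option, the raw encoder states can serve as both keys
  # and values.
  context = score_fn(decoder_query, encoder_states, encoder_states)
  attention = construct_fn(decoder_query, encoder_states, encoder_states)
  return context, attention  # Both of shape [batch_size, num_units].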
research/maskgan/models/bidirectional.py 0 → 100644
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple bidirectional model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

# ZoneoutWrapper.
from regularization import zoneout

FLAGS = tf.app.flags.FLAGS


def discriminator(hparams, sequence, is_training, reuse=None):
  """Define the bidirectional Discriminator graph."""
  sequence = tf.cast(sequence, tf.int32)

  if FLAGS.dis_share_embedding:
    assert hparams.dis_rnn_size == hparams.gen_rnn_size, (
        'If you wish to share Discriminator/Generator embeddings, they must be'
        ' the same dimension.')
    with tf.variable_scope('gen/rnn', reuse=True):
      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, hparams.gen_rnn_size])

  with tf.variable_scope('dis', reuse=reuse):
    cell_fwd = tf.contrib.rnn.LayerNormBasicLSTMCell(
        hparams.dis_rnn_size, forget_bias=1.0, reuse=reuse)
    cell_bwd = tf.contrib.rnn.LayerNormBasicLSTMCell(
        hparams.dis_rnn_size, forget_bias=1.0, reuse=reuse)
    if FLAGS.zoneout_drop_prob > 0.0:
      cell_fwd = zoneout.ZoneoutWrapper(
          cell_fwd,
          zoneout_drop_prob=FLAGS.zoneout_drop_prob,
          is_training=is_training)
      cell_bwd = zoneout.ZoneoutWrapper(
          cell_bwd,
          zoneout_drop_prob=FLAGS.zoneout_drop_prob,
          is_training=is_training)

    state_fwd = cell_fwd.zero_state(FLAGS.batch_size, tf.float32)
    state_bwd = cell_bwd.zero_state(FLAGS.batch_size, tf.float32)

    if not FLAGS.dis_share_embedding:
      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, hparams.dis_rnn_size])

    # Map tokens to embeddings and unstack along time for the static RNN.
    rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)
    rnn_inputs = tf.unstack(rnn_inputs, axis=1)

    with tf.variable_scope('rnn') as vs:
      outputs, _, _ = tf.contrib.rnn.static_bidirectional_rnn(
          cell_fwd, cell_bwd, rnn_inputs, state_fwd, state_bwd, scope=vs)

      # Prediction is linear output for Discriminator.
      predictions = tf.contrib.layers.linear(outputs, 1, scope=vs)
      predictions = tf.transpose(predictions, [1, 0, 2])

  return tf.squeeze(predictions, axis=2)
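

# A minimal invocation sketch, not part of the training code: the
# discriminator maps an int sequence of shape [batch_size, sequence_length]
# to per-timestep logits of the same shape. The hparam values and sequence
# length are hypothetical, and it assumes the MaskGAN flags (batch_size,
# vocab_size, dis_share_embedding, zoneout_drop_prob) have already been
# defined and parsed by the training binary, with dis_share_embedding=False
# (otherwise the generator's embedding variable must already exist).
def _example_discriminator():
  hparams = tf.contrib.training.HParams(dis_rnn_size=64, gen_rnn_size=64)
  sequence = tf.zeros([FLAGS.batch_size, 20], dtype=tf.int32)
  logits = discriminator(hparams, sequence, is_training=False)
  return logits  # Shape [FLAGS.batch_size, 20].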