![No Maintenance Intended](https://img.shields.io/badge/No%20Maintenance%20Intended-%E2%9C%95-red.svg)
![TensorFlow Requirement: 1.x](https://img.shields.io/badge/TensorFlow%20Requirement-1.x-brightgreen)
![TensorFlow 2 Not Supported](https://img.shields.io/badge/TensorFlow%202%20Not%20Supported-%E2%9C%95-red.svg)

# MaskGAN: Better Text Generation via Filling in the ______

Code for [*MaskGAN: Better Text Generation via Filling in the
______*](https://arxiv.org/abs/1801.07736) published at ICLR 2018.

## Requirements

*   TensorFlow >= v1.5
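
A minimal environment sketch (the exact 1.x point release and whether you need the GPU build depend on your system; the version pins below are assumptions):

```bash
# Sketch: isolated environment with a TensorFlow 1.x build; version pins are assumptions.
virtualenv maskgan_env
source maskgan_env/bin/activate
pip install 'tensorflow>=1.5,<2.0'  # or tensorflow-gpu for GPU training
```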

## Instructions

Warning: The open-source version of this code is still in the process of being
tested. Pretraining may not work correctly.

For training on PTB:

1. Follow the instructions in the [TensorFlow RNN Language Model Tutorial](https://www.tensorflow.org/tutorials/sequences/recurrent) to train a language model on the PTB dataset.
   Copy the PTB data downloaded for that tutorial into the folder `/tmp/ptb`; it should contain the following three files: `ptb.train.txt`, `ptb.test.txt`, and `ptb.valid.txt`.
   Create the folder `/tmp/pretrain-lm` and copy the checkpoints produced by the tutorial's training run into it, as in the sketch below.
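
For example, if the tutorial's data came from its `simple-examples` archive and its checkpoints were written to a local training directory, the copies might look like this (a sketch; the source paths are assumptions about your setup):

```bash
# Sketch: lay out the data and pretrained-LM checkpoints where steps 2 and 3 expect them.
# Source paths are placeholders for wherever you extracted the data / trained the LM.
mkdir -p /tmp/ptb /tmp/pretrain-lm
cp simple-examples/data/ptb.train.txt \
   simple-examples/data/ptb.valid.txt \
   simple-examples/data/ptb.test.txt /tmp/ptb/
cp /path/to/lm_train_dir/checkpoint /path/to/lm_train_dir/model.ckpt* /tmp/pretrain-lm/
```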

2. Run MaskGAN in MLE pretraining mode. If step 1 was not run, set
`language_model_ckpt_dir` to empty.

```bash
python train_mask_gan.py \
 --data_dir='/tmp/ptb' \
 --batch_size=20 \
 --sequence_length=20 \
 --base_directory='/tmp/maskGAN' \
 --hparams="gen_rnn_size=650,dis_rnn_size=650,gen_num_layers=2,dis_num_layers=2,gen_learning_rate=0.00074876,dis_learning_rate=5e-4,baseline_decay=0.99,dis_train_iterations=1,gen_learning_rate_decay=0.95" \
 --mode='TRAIN' \
 --max_steps=100000 \
 --language_model_ckpt_dir=/tmp/pretrain-lm/ \
 --generator_model='seq2seq_vd' \
 --discriminator_model='rnn_zaremba' \
 --is_present_rate=0.5 \
 --summaries_every=10 \
 --print_every=250 \
 --max_num_to_print=3 \
 --gen_training_strategy=cross_entropy \
 --seq2seq_share_embedding
```
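
Training writes summaries as it runs (see `--summaries_every`), so progress can be monitored with TensorBoard while step 2 trains. A sketch, assuming the event files land under `--base_directory`:

```bash
# Sketch: point TensorBoard at the MaskGAN training directory (assumed to hold the event files).
tensorboard --logdir=/tmp/maskGAN --port=6006
```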

3. Run MaskGAN in GAN mode. If step 2 was not run, set `maskgan_ckpt` to empty.
```bash
python train_mask_gan.py \
 --data_dir='/tmp/ptb' \
 --batch_size=128 \
 --sequence_length=20 \
 --base_directory='/tmp/maskGAN' \
 --mask_strategy=contiguous \
 --maskgan_ckpt='/tmp/maskGAN' \
 --hparams="gen_rnn_size=650,dis_rnn_size=650,gen_num_layers=2,dis_num_layers=2,gen_learning_rate=0.000038877,gen_learning_rate_decay=1.0,gen_full_learning_rate_steps=2000000,gen_vd_keep_prob=0.33971,rl_discount_rate=0.89072,dis_learning_rate=5e-4,baseline_decay=0.99,dis_train_iterations=2,dis_pretrain_learning_rate=0.005,critic_learning_rate=5.1761e-7,dis_vd_keep_prob=0.71940" \
 --mode='TRAIN' \
 --max_steps=100000 \
 --generator_model='seq2seq_vd' \
 --discriminator_model='seq2seq_vd' \
 --is_present_rate=0.5 \
 --summaries_every=250 \
 --print_every=250 \
 --max_num_to_print=3 \
 --gen_training_strategy='reinforce' \
 --seq2seq_share_embedding=true \
 --baseline_method=critic \
 --attention_option=luong
```
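
To make two of the flags above concrete: `--is_present_rate=0.5` keeps roughly half of each length-20 sequence visible, and `--mask_strategy=contiguous` hides the rest as a single span that the generator must fill in. The snippet below is an illustrative sketch of that masking, not the repository's actual implementation:

```python
import numpy as np

def contiguous_mask(sequence_length, is_present_rate, rng):
    """Illustrative sketch: True = token kept, False = token masked for the generator to fill in."""
    num_masked = int(round(sequence_length * (1.0 - is_present_rate)))
    mask = np.ones(sequence_length, dtype=bool)
    if num_masked > 0:
        start = rng.randint(0, sequence_length - num_masked + 1)
        mask[start:start + num_masked] = False  # hide one contiguous span
    return mask

print(contiguous_mask(20, 0.5, np.random.RandomState(0)))
# -> a length-20 boolean array with 10 consecutive False entries somewhere inside it
```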

4. Generate samples:
```bash
python generate_samples.py \
 --data_dir /tmp/ptb/ \
 --data_set=ptb \
 --batch_size=256 \
 --sequence_length=20 \
 --base_directory /tmp/imdbsample/ \
 --hparams="gen_rnn_size=650,dis_rnn_size=650,gen_num_layers=2,gen_vd_keep_prob=0.33971" \
 --generator_model=seq2seq_vd \
 --discriminator_model=seq2seq_vd \
 --is_present_rate=0.0 \
 --maskgan_ckpt=/tmp/maskGAN \
 --seq2seq_share_embedding=True \
 --dis_share_embedding=True \
 --attention_option=luong \
 --mask_strategy=contiguous \
 --baseline_method=critic \
 --number_epochs=4
```

## Troubleshooting

*  While trying to run step 2, the following error may appear:

   NotFoundError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

   Key critic/rnn/biases not found in checkpoint
   [[node save/RestoreV2 (defined at train_mask_gan.py:431) ]]

   This is an issue with the seq2seq model because it uses the attention mechanism.
   It arises when a checkpoint saved with an older TensorFlow version (the seq2seq code is old) is restored with a recent one (`saver.restore` was updated):
   the naming convention for LSTM parameters changed, e.g. `cell_0/basic_lstm_cell/weights` became `cell_0/basic_lstm_cell/kernel`,
   so old checkpoints cannot be restored with recent TF.
   The script below renames the variables so restoring works as expected:
   https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py
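
   A typical invocation of that conversion script looks roughly like the following (a sketch; the checkpoint paths are placeholders, and the exact arguments can be confirmed with `--help`):

   ```bash
   # Sketch: rewrite old RNN variable names (e.g. .../weights -> .../kernel) into a new checkpoint.
   python checkpoint_convert.py /tmp/pretrain-lm/model.ckpt /tmp/pretrain-lm-converted/model.ckpt
   ```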

## Contact for Issues

*   Liam Fedus, @liamb315 <liam.fedus@gmail.com>
*   Andrew M. Dai, @a-dai <adai@google.com>