Unverified Commit 17a5bdc8 authored by Lukasz Kaiser's avatar Lukasz Kaiser Committed by GitHub
Browse files

Merge pull request #3487 from a-dai/master

Fix some minor API issues.
parents 7d16fc45 9bbbeb44
...@@ -14,10 +14,11 @@ tested. Pretraining may not work correctly. ...@@ -14,10 +14,11 @@ tested. Pretraining may not work correctly.
For training on PTB: For training on PTB:
1. (Optional) Pretrain a LM on PTB and store the checkpoint in /tmp/pretrain-lm/. 1. (Optional) Pretrain a LM on PTB and store the checkpoint in `/tmp/pretrain-lm/`.
Instructions WIP. Instructions WIP.
2. (Optional) Run MaskGAN in MLE pretraining mode: 2. (Optional) Run MaskGAN in MLE pretraining mode. If step 1 was not run, set
`language_model_ckpt_dir` to empty.
```bash ```bash
python train_mask_gan.py \ python train_mask_gan.py \
...@@ -39,7 +40,7 @@ python train_mask_gan.py \ ...@@ -39,7 +40,7 @@ python train_mask_gan.py \
--seq2seq_share_embedding --seq2seq_share_embedding
``` ```
3. Run MaskGAN in GAN mode: 3. Run MaskGAN in GAN mode. If step 2 was not run, set `maskgan_ckpt` to empty.
```bash ```bash
python train_mask_gan.py \ python train_mask_gan.py \
--data_dir='/tmp/ptb' \ --data_dir='/tmp/ptb' \
......
...@@ -78,11 +78,10 @@ MODE_TEST = 'TEST' ...@@ -78,11 +78,10 @@ MODE_TEST = 'TEST'
tf.app.flags.DEFINE_enum( tf.app.flags.DEFINE_enum(
'mode', 'TRAIN', [MODE_TRAIN, MODE_VALIDATION, MODE_TEST, MODE_TRAIN_EVAL], 'mode', 'TRAIN', [MODE_TRAIN, MODE_VALIDATION, MODE_TEST, MODE_TRAIN_EVAL],
'What this binary will do.') 'What this binary will do.')
tf.app.flags.DEFINE_string('master', 'local', tf.app.flags.DEFINE_string('master', '',
"""Name of the TensorFlow master to use.""") """Name of the TensorFlow master to use.""")
tf.app.flags.DEFINE_string('eval_master', 'local', tf.app.flags.DEFINE_string('eval_master', '',
"""Name prefix of the Tensorflow eval master, """Name prefix of the Tensorflow eval master.""")
or "local".""")
tf.app.flags.DEFINE_integer('task', 0, tf.app.flags.DEFINE_integer('task', 0,
"""Task id of the replica running the training.""") """Task id of the replica running the training.""")
tf.app.flags.DEFINE_integer('ps_tasks', 0, """Number of tasks in the ps job. tf.app.flags.DEFINE_integer('ps_tasks', 0, """Number of tasks in the ps job.
...@@ -517,7 +516,7 @@ def train_model(hparams, data, log_dir, log, id_to_word, data_ngram_counts): ...@@ -517,7 +516,7 @@ def train_model(hparams, data, log_dir, log, id_to_word, data_ngram_counts):
is_chief = FLAGS.task == 0 is_chief = FLAGS.task == 0
with tf.Graph().as_default(): with tf.Graph().as_default():
with tf.device(tf.ReplicaDeviceSetter(FLAGS.ps_tasks)): with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
container_name = '' container_name = ''
with tf.container(container_name): with tf.container(container_name):
# Construct the model. # Construct the model.
...@@ -540,7 +539,7 @@ def train_model(hparams, data, log_dir, log, id_to_word, data_ngram_counts): ...@@ -540,7 +539,7 @@ def train_model(hparams, data, log_dir, log, id_to_word, data_ngram_counts):
# Create the supervisor. It will take care of initialization, # Create the supervisor. It will take care of initialization,
# summaries, checkpoints, and recovery. # summaries, checkpoints, and recovery.
sv = tf.Supervisor( sv = tf.train.Supervisor(
logdir=log_dir, logdir=log_dir,
is_chief=is_chief, is_chief=is_chief,
saver=model.saver, saver=model.saver,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment