Unverified Commit 17a5bdc8 authored by Lukasz Kaiser's avatar Lukasz Kaiser Committed by GitHub
Browse files

Merge pull request #3487 from a-dai/master

Fix some minor API issues.
parents 7d16fc45 9bbbeb44
......@@ -14,10 +14,11 @@ tested. Pretraining may not work correctly.
For training on PTB:
1. (Optional) Pretrain a LM on PTB and store the checkpoint in /tmp/pretrain-lm/.
1. (Optional) Pretrain a LM on PTB and store the checkpoint in `/tmp/pretrain-lm/`.
Instructions WIP.
2. (Optional) Run MaskGAN in MLE pretraining mode:
2. (Optional) Run MaskGAN in MLE pretraining mode. If step 1 was not run, set
`language_model_ckpt_dir` to empty.
```bash
python train_mask_gan.py \
......@@ -39,7 +40,7 @@ python train_mask_gan.py \
--seq2seq_share_embedding
```
3. Run MaskGAN in GAN mode:
3. Run MaskGAN in GAN mode. If step 2 was not run, set `maskgan_ckpt` to empty.
```bash
python train_mask_gan.py \
--data_dir='/tmp/ptb' \
......
......@@ -78,11 +78,10 @@ MODE_TEST = 'TEST'
tf.app.flags.DEFINE_enum(
'mode', 'TRAIN', [MODE_TRAIN, MODE_VALIDATION, MODE_TEST, MODE_TRAIN_EVAL],
'What this binary will do.')
tf.app.flags.DEFINE_string('master', 'local',
tf.app.flags.DEFINE_string('master', '',
"""Name of the TensorFlow master to use.""")
tf.app.flags.DEFINE_string('eval_master', 'local',
"""Name prefix of the Tensorflow eval master,
or "local".""")
tf.app.flags.DEFINE_string('eval_master', '',
"""Name prefix of the Tensorflow eval master.""")
tf.app.flags.DEFINE_integer('task', 0,
"""Task id of the replica running the training.""")
tf.app.flags.DEFINE_integer('ps_tasks', 0, """Number of tasks in the ps job.
......@@ -517,7 +516,7 @@ def train_model(hparams, data, log_dir, log, id_to_word, data_ngram_counts):
is_chief = FLAGS.task == 0
with tf.Graph().as_default():
with tf.device(tf.ReplicaDeviceSetter(FLAGS.ps_tasks)):
with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
container_name = ''
with tf.container(container_name):
# Construct the model.
......@@ -540,7 +539,7 @@ def train_model(hparams, data, log_dir, log, id_to_word, data_ngram_counts):
# Create the supervisor. It will take care of initialization,
# summaries, checkpoints, and recovery.
sv = tf.Supervisor(
sv = tf.train.Supervisor(
logdir=log_dir,
is_chief=is_chief,
saver=model.saver,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment