@@ -87,7 +86,6 @@ $ bazel run :train_classifier -- \
...
@@ -87,7 +86,6 @@ $ bazel run :train_classifier -- \
--rnn_cell_size=1024 \
--rnn_cell_size=1024 \
--cl_num_layers=1 \
--cl_num_layers=1 \
--cl_hidden_size=30 \
--cl_hidden_size=30 \
--optimizer=adam \
--batch_size=64 \
--batch_size=64 \
--learning_rate=0.0005 \
--learning_rate=0.0005 \
--learning_rate_decay_factor=0.9998 \
--learning_rate_decay_factor=0.9998 \
...
@@ -96,7 +94,8 @@ $ bazel run :train_classifier -- \
...
@@ -96,7 +94,8 @@ $ bazel run :train_classifier -- \
--num_timesteps=400 \
--num_timesteps=400 \
--keep_prob_emb=0.5 \
--keep_prob_emb=0.5 \
--normalize_embeddings \
--normalize_embeddings \
--adv_training_method=vat
--adv_training_method=vat \
--perturb_norm_length=5.0
```
```
### Evaluate on test data
### Evaluate on test data
...
@@ -136,21 +135,21 @@ adversarial training losses). The training loop itself is defined in
...
@@ -136,21 +135,21 @@ adversarial training losses). The training loop itself is defined in
### Command-Line Flags
### Command-Line Flags
Flags related to distributed training and the training loop itself are defined
Flags related to distributed training and the training loop itself are defined
in `train_utils.py`.
in [`train_utils.py`](https://github.com/tensorflow/models/tree/master/adversarial_text/train_utils.py).
Flags related to model hyperparameters are defined in `graphs.py`.
Flags related to model hyperparameters are defined in [`graphs.py`](https://github.com/tensorflow/models/tree/master/adversarial_text/graphs.py).
Flags related to adversarial training are defined in `adversarial_losses.py`.
Flags related to adversarial training are defined in [`adversarial_losses.py`](https://github.com/tensorflow/models/tree/master/adversarial_text/adversarial_losses.py).
Flags particular to each job are defined in the main binary files.
Flags particular to each job are defined in the main binary files.
* Data generation: [`gen_data.py`](https://github.com/tensorflow/models/tree/master/adversarial_text/data/gen_data.py)
Command-line flags defined in `document_generators.py` control which dataset is
Command-line flags defined in [`document_generators.py`](https://github.com/tensorflow/models/tree/master/adversarial_text/data/document_generators.py)