# Adversarial Text Classification

Code for *Adversarial Training Methods for Semi-Supervised Text Classification*
(https://arxiv.org/abs/1605.07725) and *Semi-Supervised Sequence Learning*
(https://arxiv.org/abs/1511.01432).

## Requirements

* Bazel ([install](https://bazel.build/versions/master/docs/install.html))
* TensorFlow >= v1.1

## End-to-end IMDB Sentiment Classification

### Fetch data

```
$ wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz \
    -O /tmp/imdb.tar.gz
$ tar -xf /tmp/imdb.tar.gz -C /tmp
```

The directory `/tmp/aclImdb` contains the raw IMDB data.
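Before generating the vocabulary, it can help to confirm that the archive extracted completely. The sketch below counts review files per split. It assumes the standard `aclImdb` layout (`train/pos`, `train/neg`, `train/unsup`, `test/pos`, `test/neg`, one review per `.txt` file). `count_reviews` is a hypothetical helper and is not part of this repository.

```python
import os

# Sub-directories of the standard aclImdb_v1 layout (assumed, not read from this repo).
SPLITS = ["train/pos", "train/neg", "train/unsup", "test/pos", "test/neg"]


def count_reviews(root="/tmp/aclImdb"):
    """Return {split: number of .txt review files} for each expected split."""
    counts = {}
    for split in SPLITS:
        path = os.path.join(root, split)
        if not os.path.isdir(path):
            counts[split] = 0  # Missing directory: the extraction probably failed.
            continue
        counts[split] = sum(1 for name in os.listdir(path) if name.endswith(".txt"))
    return counts


if __name__ == "__main__":
    for split, n in count_reviews().items():
        print(f"{split}: {n}")
```

On a complete download, each labeled split should hold 12,500 reviews and `train/unsup` should hold 50,000. The 25,000 test reviews correspond to `--num_examples=25000` in the evaluation step.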

### Generate vocabulary

```
$ IMDB_DATA_DIR=/tmp/imdb
$ bazel run data:gen_vocab -- \
    --output_dir=$IMDB_DATA_DIR \
    --dataset=imdb \
    --imdb_input_dir=/tmp/aclImdb \
    --lowercase=False
```

Vocabulary and frequency files will be generated in `$IMDB_DATA_DIR`.
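The following is a minimal sketch of building a vocabulary with per-token counts, included to show what this step produces. The tokenizer regex, the `min_count` cutoff, and the one-token-per-line file layout are all assumptions made for illustration. `gen_vocab.py` may tokenize, filter, and format its output differently. The `lowercase` argument mirrors what the name `--lowercase=False` suggests: with it off, `A` and `a` stay distinct tokens.

```python
import collections
import re

# Hypothetical tokenizer: runs of word characters, or single punctuation marks.
_TOKEN_RE = re.compile(r"\w+|[^\w\s]")


def build_vocab(documents, lowercase=False, min_count=1):
    """Return (token, count) pairs sorted by descending count, then alphabetically.

    `min_count` is a hypothetical cutoff for rare tokens.
    """
    counts = collections.Counter()
    for doc in documents:
        if lowercase:
            doc = doc.lower()
        counts.update(_TOKEN_RE.findall(doc))
    kept = [(tok, n) for tok, n in counts.items() if n >= min_count]
    kept.sort(key=lambda pair: (-pair[1], pair[0]))  # Stable, deterministic order.
    return kept


def write_vocab(pairs, vocab_path, freq_path):
    """Write tokens and their counts to two parallel files, one entry per line."""
    with open(vocab_path, "w") as vf, open(freq_path, "w") as ff:
        for tok, n in pairs:
            vf.write(tok + "\n")
            ff.write(f"{n}\n")


if __name__ == "__main__":
    print(build_vocab(["A great movie!", "a terrible movie."]))
```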

### Generate training, validation, and test data

```
$ bazel run data:gen_data -- \
    --output_dir=$IMDB_DATA_DIR \
    --dataset=imdb \
    --imdb_input_dir=/tmp/aclImdb \
    --lowercase=False \
    --label_gain=False
```

`$IMDB_DATA_DIR` now also contains the training, validation, and test data as TFRecord files.

### Pretrain IMDB Language Model

```
$ PRETRAIN_DIR=/tmp/models/imdb_pretrain
$ bazel run :pretrain -- \
    --train_dir=$PRETRAIN_DIR \
    --data_dir=$IMDB_DATA_DIR \
    --vocab_size=86934 \
    --embedding_dims=256 \
    --rnn_cell_size=1024 \
    --num_candidate_samples=1024 \
    --optimizer=adam \
    --batch_size=256 \
    --learning_rate=0.001 \
    --learning_rate_decay_factor=0.9999 \
    --max_steps=100000 \
    --max_grad_norm=1.0 \
    --num_timesteps=400 \
    --keep_prob_emb=0.5 \
    --normalize_embeddings
```

`$PRETRAIN_DIR` contains checkpoints of the pretrained language model.
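The optimization flags can be read with a little arithmetic. The sketch below makes two assumptions about the training code. First, the learning rate decays exponentially once per global step (`lr_t = learning_rate * learning_rate_decay_factor ** t`). Second, `--max_grad_norm` clips gradients by their joint (global) L2 norm, in the spirit of `tf.clip_by_global_norm`. It is written in plain Python rather than TensorFlow.

```python
import math


def decayed_learning_rate(base_lr, decay_factor, step):
    """Exponential decay applied once per step (assumed schedule)."""
    return base_lr * decay_factor ** step


def clip_by_global_norm(grads, max_norm):
    """Scale gradient vectors so that their joint L2 norm is at most max_norm.

    Every gradient is scaled by the same factor, so the update direction is
    preserved. Returns (clipped_grads, norm_before_clipping).
    """
    global_norm = math.sqrt(sum(g_i * g_i for g in grads for g_i in g))
    if global_norm <= max_norm:
        return grads, global_norm
    scale = max_norm / global_norm
    return [[g_i * scale for g_i in g] for g in grads], global_norm


if __name__ == "__main__":
    # Pretraining: --learning_rate=0.001 --learning_rate_decay_factor=0.9999
    print(decayed_learning_rate(0.001, 0.9999, 100000))
    # Classifier: --learning_rate=0.0005 --learning_rate_decay_factor=0.9998
    print(decayed_learning_rate(0.0005, 0.9998, 15000))
    # Joint norm is 5.0, so with max_norm=1.0 every entry is scaled by 0.2.
    print(clip_by_global_norm([[3.0, 0.0], [0.0, 4.0]], max_norm=1.0))
```

Under the per-step assumption, pretraining ends with a learning rate of roughly 4.5e-8 after `--max_steps=100000`, more than four orders of magnitude below the initial 0.001. The classifier settings end near 2.5e-5. If the real schedule decays less often, the final rates are correspondingly higher.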

### Train classifier

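Most flags stay the same as for pretraining, with these differences:
candidate sampling (`--num_candidate_samples`) is dropped; `--pretrained_model_dir`
is added, from which the classifier loads the pretrained embedding and LSTM
variables; flags for classification (`--cl_num_layers`, `--cl_hidden_size`) and
adversarial training (`--adv_training_method`) are added; and the optimization
settings (`--batch_size`, `--learning_rate`, `--learning_rate_decay_factor`,
`--max_steps`) take new values.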

```
$ TRAIN_DIR=/tmp/models/imdb_classify
$ bazel run :train_classifier -- \
    --train_dir=$TRAIN_DIR \
    --pretrained_model_dir=$PRETRAIN_DIR \
    --data_dir=$IMDB_DATA_DIR \
    --vocab_size=86934 \
    --embedding_dims=256 \
    --rnn_cell_size=1024 \
    --cl_num_layers=1 \
    --cl_hidden_size=30 \
    --optimizer=adam \
    --batch_size=64 \
    --learning_rate=0.0005 \
    --learning_rate_decay_factor=0.9998 \
    --max_steps=15000 \
    --max_grad_norm=1.0 \
    --num_timesteps=400 \
    --keep_prob_emb=0.5 \
    --normalize_embeddings \
    --adv_training_method=vat
```
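`--adv_training_method=vat` selects virtual adversarial training. As described in the paper, the perturbation is `r_vadv = epsilon * g / ||g||_2`, where `g` is the gradient of `KL(p(.|x) || p(.|x + r))` with respect to `r`, evaluated at a small random `r = xi * d`. No labels are needed. The sketch below uses a toy logistic-regression model over a single input vector so that `g` has a closed form, `(q - p) * w`. The model and the `epsilon` and `xi` values are illustrative assumptions, not this repository's implementation, which perturbs word embeddings and computes gradients with TensorFlow.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def l2_normalize(v, eps=1e-12):
    """Scale v to unit L2 norm (eps avoids division by zero)."""
    norm = math.sqrt(dot(v, v))
    return [x / (norm + eps) for x in v]


def bernoulli_kl(p, q):
    """KL(Bernoulli(p) || Bernoulli(q)) for p, q strictly inside (0, 1)."""
    return p * math.log(p / q) + (1 - p) * math.log((1 - p) / (1 - q))


def vat_perturbation(x, w, b, epsilon=5.0, xi=0.1, rng=random):
    """Virtual adversarial perturbation for a toy model p(y=1|x) = sigmoid(w.x + b).

    1. Draw a random unit direction d and probe the model at x + xi * d.
    2. Gradient of KL(p(.|x) || p(.|x + r)) w.r.t. r at r = xi * d; for this
       model it has the closed form (q - p) * w.
    3. Rescale the gradient to length epsilon.
    No label is used, so the same step applies to unlabeled reviews.
    """
    p = sigmoid(dot(w, x) + b)
    d = l2_normalize([rng.gauss(0.0, 1.0) for _ in x])
    x_probe = [xj + xi * dj for xj, dj in zip(x, d)]
    q = sigmoid(dot(w, x_probe) + b)
    grad = [(q - p) * wj for wj in w]
    return [epsilon * gj for gj in l2_normalize(grad)]


def vat_loss(x, w, b, r_adv):
    """KL between predictions at x and at x + r_adv (the term added to the loss)."""
    p = sigmoid(dot(w, x) + b)
    q = sigmoid(dot(w, [xj + rj for xj, rj in zip(x, r_adv)]) + b)
    return bernoulli_kl(p, q)


if __name__ == "__main__":
    rng = random.Random(0)
    x, w, b = [0.5, -1.0, 2.0], [1.0, 2.0, -0.5], 0.1
    r_adv = vat_perturbation(x, w, b, rng=rng)
    print("perturbation norm:", math.sqrt(dot(r_adv, r_adv)))
    print("VAT loss:", vat_loss(x, w, b, r_adv))
```

For a linear model the result reduces to `±epsilon * w / ||w||`, the direction that moves the logit most. The random probe matters for nonlinear models such as the LSTM classifier, where the gradient direction depends on `d`.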

### Evaluate on test data

```
$ EVAL_DIR=/tmp/models/imdb_eval
$ bazel run :evaluate -- \
    --eval_dir=$EVAL_DIR \
    --checkpoint_dir=$TRAIN_DIR \
    --eval_data=test \
    --run_once \
    --num_examples=25000 \
    --data_dir=$IMDB_DATA_DIR \
    --vocab_size=86934 \
    --embedding_dims=256 \
    --rnn_cell_size=1024 \
    --batch_size=256 \
    --num_timesteps=400 \
    --normalize_embeddings
```

## Code Overview

The main entry points are the binaries listed below. Each training binary builds
a `VatxtModel`, defined in `graphs.py`, which in turn uses graph building blocks
from `inputs.py` (input data reading and parsing), `layers.py` (core model
components), and `adversarial_losses.py` (adversarial training losses). The
training loop itself is defined in `train_utils.py`.
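The `--normalize_embeddings` flag passed to all three binaries relates to a detail of the paper. Each word embedding `v_i` is standardized with frequency-weighted statistics, `(v_i - E(v)) / sqrt(Var(v))`, where `E(v) = sum_j f_j * v_j`, `Var(v) = sum_j f_j * (v_j - E(v))**2`, and `f_j` is the relative frequency of word `j`. The paper motivates this by noting that, without it, the model could make perturbations insignificant by learning embeddings with very large norms. The plain-Python sketch below computes these statistics per embedding dimension. The function name and the small `eps` are illustrative and not taken from `layers.py`.

```python
import math


def normalize_embeddings(embeddings, counts, eps=1e-6):
    """Standardize each embedding dimension with frequency-weighted statistics.

    embeddings: one vector (list of floats) per vocabulary word.
    counts: raw occurrence count of each word, e.g. from the frequency file.
    """
    total = float(sum(counts))
    weights = [c / total for c in counts]  # f_j, sums to 1
    dims = len(embeddings[0])
    mean = [sum(f * v[k] for f, v in zip(weights, embeddings)) for k in range(dims)]
    var = [
        sum(f * (v[k] - mean[k]) ** 2 for f, v in zip(weights, embeddings))
        for k in range(dims)
    ]
    std = [math.sqrt(var_k + eps) for var_k in var]  # eps guards against zero variance
    return [[(v[k] - mean[k]) / std[k] for k in range(dims)] for v in embeddings]


if __name__ == "__main__":
    embeddings = [[1.0, 2.0], [3.0, -1.0], [5.0, 0.5]]
    counts = [10, 5, 1]  # The first word is by far the most frequent.
    for row in normalize_embeddings(embeddings, counts):
        print([round(value, 3) for value in row])
```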

### Binaries

*   Pretraining: `pretrain.py`
*   Classifier Training: `train_classifier.py`
*   Evaluation: `evaluate.py`

### Command-Line Flags

Flags related to distributed training and the training loop itself are defined
in `train_utils.py`.

Flags related to model hyperparameters are defined in `graphs.py`.

Flags related to adversarial training are defined in `adversarial_losses.py`.

Flags particular to each job are defined in the main binary files.

### Data Generation

*   Vocabulary generation: `gen_vocab.py`
*   Data generation: `gen_data.py`

Command-line flags defined in `document_generators.py` control which dataset is
processed and how.

## Contact for Issues

* Ryan Sepassi, @rsepassi