# Model Garden NLP Common Training Driver

[train.py](https://github.com/tensorflow/models/blob/master/official/nlp/train.py) is the common training driver that supports multiple
NLP tasks (e.g., pre-training, GLUE and SQuAD fine-tuning) and multiple
models (e.g., BERT, ALBERT, and MobileBERT).

## Experiment Configuration

[train.py](https://github.com/tensorflow/models/blob/master/official/nlp/train.py) is driven by configurations defined by an
[ExperimentConfig](https://github.com/tensorflow/models/blob/master/official/core/config_definitions.py),
which bundles the `task`, `trainer`, and `runtime` configurations. The
pre-defined NLP-related experiment configurations can be found in
[configs/experiment_configs.py](https://github.com/tensorflow/models/blob/master/official/nlp/configs/experiment_configs.py).
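
Concretely, an `ExperimentConfig` nests the `task`, `trainer`, and `runtime`
configs, and the dotted keys accepted by `--params_override` (see below) follow
this nesting. Below is a minimal, hypothetical sketch using only the base
classes from `config_definitions.py`; real experiments use task-specific
subclasses such as those registered in `configs/finetuning_experiments.py`, and
the field values here are placeholders:

```python
# A minimal sketch, not a registered experiment. Values are placeholders.
from official.core import config_definitions as cfg

experiment = cfg.ExperimentConfig(
    task=cfg.TaskConfig(  # real experiments use a task-specific subclass
        train_data=cfg.DataConfig(
            input_path='/path/to/train.tf_record',
            is_training=True,
            global_batch_size=32)),
    trainer=cfg.TrainerConfig(train_steps=10000),
    runtime=cfg.RuntimeConfig(distribution_strategy='tpu'))

# The attribute path mirrors the `task.train_data.input_path` override key.
print(experiment.task.train_data.input_path)
```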

## Experiment Registry

We use an [experiment registry](https://github.com/tensorflow/models/blob/master/official/core/exp_factory.py) to build a mapping
from experiment type to experiment configuration instance. For example,
[configs/finetuning_experiments.py](https://github.com/tensorflow/models/blob/master/official/nlp/configs/finetuning_experiments.py)
registers the `bert/sentence_prediction` and `bert/squad` experiments. You can
use the `--experiment` flag to invoke a registered experiment configuration,
e.g., `--experiment=bert/sentence_prediction`.
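
A registration is just a factory function, decorated with
`exp_factory.register_config_factory`, that returns an `ExperimentConfig`. The
sketch below is hypothetical (the `my_toy/experiment` name and its values are
made up for illustration); see
[configs/finetuning_experiments.py](https://github.com/tensorflow/models/blob/master/official/nlp/configs/finetuning_experiments.py)
for the real registrations:

```python
# A hedged sketch of experiment registration; 'my_toy/experiment' is a
# made-up name used only for illustration.
from official.core import config_definitions as cfg
from official.core import exp_factory


@exp_factory.register_config_factory('my_toy/experiment')
def my_toy_experiment() -> cfg.ExperimentConfig:
  """Returns the default config; YAML/flag overrides are applied on top."""
  return cfg.ExperimentConfig(
      task=cfg.TaskConfig(),
      trainer=cfg.TrainerConfig(train_steps=1000),
      runtime=cfg.RuntimeConfig(distribution_strategy='mirrored'))


# train.py performs (roughly) this lookup for the value of --experiment.
config = exp_factory.get_exp_config('my_toy/experiment')
```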

## Overriding Configuration via YAML and Flags

The registered experiment configuration can be overridden by one or more
YAML files provided via the `--config_file` flag. For example:

```shell
--config_file=configs/experiments/glue_mnli_matched.yaml \
--config_file=configs/models/bert_en_uncased_base.yaml
```

In addition, the experiment configuration can be further overridden via the
`--params_override` flag. For example:

```shell
 --params_override=task.train_data.input_path=/some/path,task.hub_module_url=/some/tfhub
```
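
Both the YAML files and the `--params_override` entries are merged onto the
registered default configuration as nested overrides, with `--params_override`
applied after the YAML files. The following is a rough, hypothetical sketch of
the semantics only (it is not the exact code path in `train.py`, and the paths
are placeholders):

```python
# A rough sketch of the override semantics; paths are placeholders.
from official.core import config_definitions as cfg

config = cfg.ExperimentConfig(
    task=cfg.TaskConfig(train_data=cfg.DataConfig()),
    trainer=cfg.TrainerConfig(),
    runtime=cfg.RuntimeConfig())

# Roughly what
#   --params_override=task.train_data.input_path=/some/path,runtime.distribution_strategy=tpu
# amounts to:
config.override({
    'task': {'train_data': {'input_path': '/some/path'}},
    'runtime': {'distribution_strategy': 'tpu'},
})
```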

## Run on Cloud TPUs

Next, we describe how to run [train.py](https://github.com/tensorflow/models/blob/master/official/nlp/train.py) on Cloud TPUs.

### Setup

First, you need to create a `tf-nightly` TPU with the
[ctpu tool](https://github.com/tensorflow/tpu/tree/master/tools/ctpu):

```shell
export TPU_NAME=YOUR_TPU_NAME
ctpu up -name $TPU_NAME --tf-version=nightly --tpu-size=YOUR_TPU_SIZE --project=YOUR_PROJECT
```

and then install Model Garden and required dependencies:

```shell
git clone https://github.com/tensorflow/models.git
export PYTHONPATH=$PYTHONPATH:/path/to/models
pip3 install --user -r official/requirements.txt
```

### Fine-tuning Sentence Classification with BERT from TF-Hub

This example fine-tunes BERT-base from TF-Hub on the Multi-Genre Natural
Language Inference (MultiNLI) corpus using TPUs.

First, you can prepare the fine-tuning data using the
[`create_finetuning_data.py`](https://github.com/tensorflow/models/blob/master/official/nlp/data/create_finetuning_data.py) script.
For GLUE tasks, you can (1) download the
[GLUE data](https://gluebenchmark.com/tasks) by running
[this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
and unpack it to some directory `$GLUE_DIR`, (2) prepare the vocabulary file,
and (3) run the following command:

```shell
export GLUE_DIR=~/glue
export VOCAB_FILE=~/uncased_L-12_H-768_A-12/vocab.txt

export TASK_NAME=MNLI
export OUTPUT_DATA_DIR=gs://some_bucket/datasets
python3 data/create_finetuning_data.py \
 --input_data_dir=${GLUE_DIR}/${TASK_NAME}/ \
 --vocab_file=${VOCAB_FILE} \
 --train_data_output_path=${OUTPUT_DATA_DIR}/${TASK_NAME}_train.tf_record \
 --eval_data_output_path=${OUTPUT_DATA_DIR}/${TASK_NAME}_eval.tf_record \
 --meta_data_file_path=${OUTPUT_DATA_DIR}/${TASK_NAME}_meta_data \
 --fine_tuning_task_type=classification --max_seq_length=128 \
 --classification_task_name=${TASK_NAME}
```

The resulting training and evaluation datasets in `tf_record` format will
later be passed to [train.py](https://github.com/tensorflow/models/blob/master/official/nlp/train.py).
Support for reading datasets from TensorFlow Datasets (TFDS) and using
`tf.text` for pre-processing is coming soon.

Then you can execute the following commands to start the training and evaluation
job.

```shell
export INPUT_DATA_DIR=gs://some_bucket/datasets
export OUTPUT_DIR=gs://some_bucket/my_output_dir

# See the TF Hub BERT collection for more models:
# https://tfhub.dev/google/collections/bert/1
export BERT_HUB_URL=https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3

# Override the configurations via flags. Alternatively, you can directly edit
# `configs/experiments/glue_mnli_matched.yaml` to specify the corresponding fields.
export PARAMS=task.train_data.input_path=$INPUT_DATA_DIR/MNLI_train.tf_record
export PARAMS=$PARAMS,task.validation_data.input_path=$INPUT_DATA_DIR/MNLI_eval.tf_record
export PARAMS=$PARAMS,task.hub_module_url=$BERT_HUB_URL
export PARAMS=$PARAMS,runtime.distribution_strategy=tpu

python3 train.py \
 --experiment=bert/sentence_prediction \
 --mode=train_and_eval \
 --model_dir=$OUTPUT_DIR \
 --config_file=configs/experiments/glue_mnli_matched.yaml \
 --tfhub_cache_dir=$OUTPUT_DIR/hub_cache \
 --tpu=${TPU_NAME} \
 --params_override=$PARAMS
```

You can monitor the training progress in the console and find the output
models in `$OUTPUT_DIR`.

### Fine-tuning SQuAD with a pre-trained BERT checkpoint

This example fine-tunes a pre-trained BERT checkpoint on the
Stanford Question Answering Dataset (SQuAD) using TPUs.
The [SQuAD website](https://rajpurkar.github.io/SQuAD-explorer/) contains
detailed information about the SQuAD datasets and evaluation. After downloading
the SQuAD datasets and the [pre-trained BERT checkpoints](https://github.com/tensorflow/models/blob/master/official/nlp/docs/pretrained_models.md),
you can run the following command to prepare the `tf_record` files:

```shell
export SQUAD_DIR=~/squad
export BERT_DIR=~/uncased_L-12_H-768_A-12
export OUTPUT_DATA_DIR=gs://some_bucket/datasets

python3 data/create_finetuning_data.py \
 --squad_data_file=${SQUAD_DIR}/train-v1.1.json \
 --vocab_file=${BERT_DIR}/vocab.txt \
 --train_data_output_path=${OUTPUT_DATA_DIR}/train.tf_record \
 --meta_data_file_path=${OUTPUT_DATA_DIR}/squad_meta_data \
 --fine_tuning_task_type=squad --max_seq_length=384
```

Note: To create fine-tuning data for SQuAD 2.0, you need to add the flag `--version_2_with_negative=True`.

Then, you can start the training and evaluation jobs:

```shell
export SQUAD_DIR=~/squad
export INPUT_DATA_DIR=gs://some_bucket/datasets
export OUTPUT_DIR=gs://some_bucket/my_output_dir

# See the following link for more pre-trained checkpoints:
# https://github.com/tensorflow/models/blob/master/official/nlp/docs/pretrained_models.md
export BERT_DIR=~/uncased_L-12_H-768_A-12

# Override the configurations via flags. Alternatively, you can directly edit
# `configs/experiments/squad_v1.1.yaml` to specify the corresponding fields.
# Also note that the training data is the pre-processed tf_record file, while
# the validation file is the raw JSON file.
export PARAMS=task.train_data.input_path=$INPUT_DATA_DIR/train.tf_record
export PARAMS=$PARAMS,task.validation_data.input_path=$SQUAD_DIR/dev-v1.1.json
export PARAMS=$PARAMS,task.validation_data.vocab_file=$BERT_DIR/vocab.txt
export PARAMS=$PARAMS,task.init_checkpoint=$BERT_DIR/bert_model.ckpt
export PARAMS=$PARAMS,runtime.distribution_strategy=tpu

python3 train.py \
 --experiment=bert/squad \
 --mode=train_and_eval \
 --model_dir=$OUTPUT_DIR \
 --config_file=configs/experiments/squad_v1.1.yaml \
 --tpu=${TPU_NAME} \
 --params_override=$PARAMS
```

Note: More pre-training examples will be added soon.