"vscode:/vscode.git/clone" did not exist on "b688fd858d89258ba0018e5903c2907badf49afa"
train.md 4.25 KB
Newer Older
1
2
# Model Garden NLP Common Training Driver

[train.py](https://github.com/tensorflow/models/blob/master/official/nlp/train.py) is the common training driver that supports multiple
NLP tasks (e.g., pre-training, GLUE and SQuAD fine-tuning) and multiple
models (e.g., BERT, ALBERT, MobileBERT).

## Experiment Configuration

[train.py](https://github.com/tensorflow/models/blob/master/official/nlp/train.py) is driven by configs defined by the
[ExperimentConfig](https://github.com/tensorflow/models/blob/master/official/core/config_definitions.py),
including configurations for `task`, `trainer` and `runtime`. The pre-defined
NLP-related experiment configurations can be found in
[configs/experiment_configs.py](https://github.com/tensorflow/models/blob/master/official/nlp/configs/experiment_configs.py).

## Experiment Registry

We use an [experiment registry](https://github.com/tensorflow/models/blob/master/official/core/exp_factory.py) to build a mapping
from experiment type to experiment configuration instance. For example,
[configs/finetuning_experiments.py](https://github.com/tensorflow/models/blob/master/official/nlp/configs/finetuning_experiments.py)
registers the `bert/sentence_prediction` and `bert/squad` experiments. Users can use
the `--experiment` FLAG to invoke a registered experiment configuration,
e.g., `--experiment=bert/sentence_prediction`.
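
To inspect a registered configuration outside of the training driver, a minimal
sketch is shown below. It assumes the `get_exp_config` helper in
[exp_factory.py](https://github.com/tensorflow/models/blob/master/official/core/exp_factory.py)
and that importing `official.nlp.configs.experiment_configs` triggers the NLP
registrations; details may differ across Model Garden versions.

```shell
# Sketch: assumes exp_factory.get_exp_config and that importing
# official.nlp.configs.experiment_configs registers the NLP experiments.
python3 -c "
from official.core import exp_factory
from official.nlp.configs import experiment_configs  # registers the NLP experiments
config = exp_factory.get_exp_config('bert/sentence_prediction')
print(config.task)
"
```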

## Overriding Configuration via YAML and FLAGS

The registered experiment configuration can be overridden by one or
more YAML files provided via the `--config_file` FLAG. For example:

```shell
--config_file=configs/experiments/glue_mnli_matched.yaml \
--config_file=configs/models/bert_en_uncased_base.yaml
```

In addition, the experiment configuration can be further overridden by the
`--params_override` FLAG. For example:

```shell
--params_override=task.train_data.input_path=/some/path,task.hub_module_url=/some/tfhub
```
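
The same comma-separated syntax reaches any field of the `ExperimentConfig`,
for example the trainer settings. The field names below are only illustrative;
check [config_definitions.py](https://github.com/tensorflow/models/blob/master/official/core/config_definitions.py)
for the fields available to your experiment:

```shell
# Illustrative trainer fields; see config_definitions.py for the full set.
--params_override=trainer.train_steps=10000,trainer.steps_per_loop=100
```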

## Run on Cloud TPUs

Next, we will describe how to run [train.py](https://github.com/tensorflow/models/blob/master/official/nlp/train.py) on Cloud TPUs.

### Setup
First, you need to create a `tf-nightly` TPU with the
[ctpu tool](https://github.com/tensorflow/tpu/tree/master/tools/ctpu):

```shell
export TPU_NAME=YOUR_TPU_NAME
ctpu up -name $TPU_NAME --tf-version=nightly --tpu-size=YOUR_TPU_SIZE --project=YOUR_PROJECT
```

and then install Model Garden and required dependencies:

```shell
git clone https://github.com/tensorflow/models.git
export PYTHONPATH=$PYTHONPATH:/path/to/models
# Install the dependencies from the repository root so that the relative
# requirements path resolves.
cd /path/to/models
pip3 install --user -r official/requirements.txt
```
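
As an optional sanity check, you can confirm that the VM's TensorFlow build is
the `tf-nightly` version requested for the TPU:

```shell
# Print the installed TensorFlow version; it should report a nightly build.
python3 -c "import tensorflow as tf; print(tf.__version__)"
```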

### Fine-tuning Sentence Classification with BERT from TF-Hub

This example fine-tunes BERT-base from TF-Hub on the Multi-Genre Natural
Language Inference (MultiNLI) corpus using TPUs.

First, you can prepare the fine-tuning data using the
[`data/create_finetuning_data.py`](https://github.com/tensorflow/models/blob/master/official/nlp/data/create_finetuning_data.py) script.
The resulting training and evaluation datasets in `tf_record` format will later be
passed to [train.py](https://github.com/tensorflow/models/blob/master/official/nlp/train.py).
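
For example, a sketch of the MNLI data preparation step is shown below. The
flag names follow the script at the time of writing and may differ in your
Model Garden version (check `--help`), and the local paths are placeholders
you need to point at your own copies of the GLUE data and the BERT vocabulary:

```shell
# Sketch: flag names and paths are illustrative; run the script with --help
# to confirm the flags in your checkout.
export GLUE_DIR=/path/to/glue_data        # downloaded MNLI (GLUE) data
export BERT_DIR=/path/to/bert_dir         # directory containing vocab.txt
export INPUT_DATA_DIR=gs://some_bucket/datasets

python3 data/create_finetuning_data.py \
 --fine_tuning_task_type=classification \
 --classification_task_name=MNLI \
 --input_data_dir=$GLUE_DIR/MNLI \
 --vocab_file=$BERT_DIR/vocab.txt \
 --train_data_output_path=$INPUT_DATA_DIR/mnli_train.tf_record \
 --eval_data_output_path=$INPUT_DATA_DIR/mnli_eval.tf_record \
 --meta_data_file_path=$INPUT_DATA_DIR/mnli_meta_data
```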

Then you can execute the following commands to start the training and evaluation
job.

```shell
export INPUT_DATA_DIR=gs://some_bucket/datasets
export OUTPUT_DIR=gs://some_bucket/my_output_dir

# See tfhub BERT collection for more tfhub models:
# https://tfhub.dev/google/collections/bert/1
export BERT_HUB_URL=https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3

# Override the configurations by FLAGS. Alternatively, you can directly edit
# `configs/experiments/glue_mnli_matched.yaml` to specify corresponding fields.
export PARAMS=task.train_data.input_path=$INPUT_DATA_DIR/mnli_train.tf_record
export PARAMS=$PARAMS,task.validation_data.input_path=$INPUT_DATA_DIR/mnli_eval.tf_record
export PARAMS=$PARAMS,task.hub_module_url=$BERT_HUB_URL
export PARAMS=$PARAMS,runtime.distribution_strategy=tpu

python3 train.py \
 --experiment=bert/sentence_prediction \
 --mode=train_and_eval \
 --model_dir=$OUTPUT_DIR \
 --config_file=configs/experiments/glue_mnli_matched.yaml \
 --tfhub_cache_dir=$OUTPUT_DIR/hub_cache \
 --tpu=${TPU_NAME} \
 --params_override=$PARAMS

```

You can monitor the training progress in the console and find the output
models in `$OUTPUT_DIR`.
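
If you prefer TensorBoard over the console logs, you can point it at the same
directory; this assumes TensorBoard is installed and that the driver writes
its summaries under `$OUTPUT_DIR`:

```shell
# Assumes summaries are written under the model directory ($OUTPUT_DIR).
tensorboard --logdir=$OUTPUT_DIR
```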

Note: More examples of pre-training and fine-tuning will come soon.