Commit bc8e6d24 authored by Hongkun Yu, committed by A. Unique TensorFlower

[Docs] Add BERT pre-training experiment documentation to train.md

https://github.com/tensorflow/models/issues/10074

PiperOrigin-RevId: 463195812
parent 0f182173
task:
  init_checkpoint: ''
  model:
    cls_heads: [{activation: tanh, cls_token_idx: 0, dropout_rate: 0.1, inner_dim: 768, name: next_sentence, num_classes: 2}]
  train_data:
    drop_remainder: true
    global_batch_size: 512
    input_path: '[Your processed wiki data path]*,[Your processed books data path]*'
    is_training: true
    max_predictions_per_seq: 76
    seq_length: 512
    use_next_sentence_label: true
    use_position_id: false
    use_v2_feature_names: true
  validation_data:
    drop_remainder: false
    global_batch_size: 512
    input_path: '[Your processed wiki data path]-00000-of-00500,[Your processed books data path]-00000-of-00500'
    is_training: false
    max_predictions_per_seq: 76
    seq_length: 512
    use_next_sentence_label: true
    use_position_id: false
    use_v2_feature_names: true
trainer:
  checkpoint_interval: 20000
  max_to_keep: 5
  optimizer_config:
    learning_rate:
      polynomial:
        cycle: false
        decay_steps: 1000000
        end_learning_rate: 0.0
        initial_learning_rate: 0.0001
        power: 1.0
      type: polynomial
    optimizer:
      type: adamw
    warmup:
      polynomial:
        power: 1
        warmup_steps: 10000
      type: polynomial
  steps_per_loop: 1000
  summary_interval: 1000
  train_steps: 1000000
  validation_interval: 1000
  validation_steps: 64
@@ -229,6 +229,73 @@ python3 train.py \
```
</details>

### Pre-train a BERT from scratch

This example pre-trains a BERT model with the Wikipedia and Books datasets used
by the original BERT paper.
The [BERT repo](https://github.com/tensorflow/models/blob/master/official/nlp/data/create_pretraining_data.py)
contains detailed information about the Wikipedia dump and
[BookCorpus](https://yknzhu.wixsite.com/mbweb). Of course, the pre-training
recipe is generic, and you can apply it to your own corpus.

Please use the script
[`create_pretraining_data.py`](https://github.com/tensorflow/models/blob/master/official/nlp/data/create_pretraining_data.py),
which is essentially branched from the [BERT research repo](https://github.com/google-research/bert)
and adapted to TF2 symbols and Python 3 compatibility, to generate the processed
pre-training data.

Running the pre-training script requires an input and output directory, as well
as a vocab file. Note that `max_seq_length` will need to match the sequence
length parameter you specify when you run pre-training.
```shell
export WORKING_DIR='local disk or cloud location'
export BERT_DIR='local disk or cloud location'
# Convert the raw text corpus into masked-LM TFRecord examples. The values of
# --max_seq_length and --max_predictions_per_seq must match seq_length and
# max_predictions_per_seq in the pre-training yaml config.
python models/official/nlp/data/create_pretraining_data.py \
  --input_file=$WORKING_DIR/input/input.txt \
  --output_file=$WORKING_DIR/output/tf_examples.tfrecord \
  --vocab_file=$BERT_DIR/wwm_uncased_L-24_H-1024_A-16/vocab.txt \
  --do_lower_case=True \
  --max_seq_length=512 \
  --max_predictions_per_seq=76 \
  --masked_lm_prob=0.15 \
  --random_seed=12345 \
  --dupe_factor=5
```
Then, you can update the yaml configuration file, e.g.
`configs/experiments/wiki_books_pretrain.yaml`, to specify your data paths and
update the masking-related hyperparameters to match the settings you used when
creating the pre-training data. When your data has multiple shards, you can
use `*` to include multiple files, as in the sketch below.
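For example, the data-related overrides might end up looking like this minimal
sketch; the `gs://your_bucket/...` paths are placeholders for your own TFRecord
locations, and the masking values simply mirror the `create_pretraining_data.py`
flags shown above:

```yaml
task:
  train_data:
    # Placeholder paths; the trailing `*` picks up every TFRecord shard.
    input_path: 'gs://your_bucket/wiki/tf_examples.tfrecord*,gs://your_bucket/books/tf_examples.tfrecord*'
    seq_length: 512                # must match --max_seq_length
    max_predictions_per_seq: 76    # must match --max_predictions_per_seq
```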
To train different BERT sizes, you need to adjust:
```
model:
cls_heads: [{activation: tanh, cls_token_idx: 0, dropout_rate: 0.1, inner_dim: 768, name: next_sentence, num_classes: 2}]
```
so that `inner_dim` matches the encoder's hidden size; see the sketch below.
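For instance, if you swap in a BERT-large encoder (hidden size 1024), the head
would be adjusted as in this sketch:

```yaml
model:
  # Sketch for a BERT-large encoder: inner_dim follows its 1024 hidden size.
  cls_heads: [{activation: tanh, cls_token_idx: 0, dropout_rate: 0.1, inner_dim: 1024, name: next_sentence, num_classes: 2}]
```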
Then, you can start the training and evaluation job, which runs the
[`bert/pretraining`](https://github.com/tensorflow/models/blob/master/official/nlp/configs/pretraining_experiments.py#L51)
experiment:
```shell
export OUTPUT_DIR=gs://some_bucket/my_output_dir
# Append the TPU distribution strategy to any other overrides already in $PARAMS.
export PARAMS=$PARAMS,runtime.distribution_strategy=tpu
# ${TPU_NAME} is assumed to be set to the name of your Cloud TPU.
python3 train.py \
  --experiment=bert/pretraining \
  --mode=train_and_eval \
  --model_dir=$OUTPUT_DIR \
  --config_file=configs/models/bert_en_uncased_base.yaml \
  --config_file=configs/experiments/wiki_books_pretrain.yaml \
  --tpu=${TPU_NAME} \
  --params_override=$PARAMS
```
Note: More examples about pre-training with TFDS datasets will come soon.