Unverified commit 2d70c912 authored by Patrick von Platen, committed by GitHub

[Flax] Adapt flax examples to include `push_to_hub` (#12391)



* fix_torch_device_generate_test

* remove @

* finish

* correct summary writer

* correct push to hub

* fix indent

* finish

* finish

* finish

* finish

* finish
Co-authored-by: Patrick von Platen <patrick@huggingface.co>
parent a7d0b288
...@@ -33,11 +33,37 @@ in Norwegian on a single TPUv3-8 pod.

The example script uses the 🤗 Datasets library. You can easily customize it if your datasets need extra processing.

Let's start by creating a model repository to save the trained model and logs.
Here we call the model `"norwegian-roberta-base"`, but you can change the model name as you like.
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
you are logged in) or via the command line:
```
huggingface-cli repo create norwegian-roberta-base
```
Next we clone the model repository to add the tokenizer and model files.
```
git clone https://huggingface.co/<your-username>/norwegian-roberta-base
```
To ensure that all tensorboard traces are uploaded correctly, we need to
track them with Git LFS. You can run the following command inside your model repo to do so.
```
cd norwegian-roberta-base
git lfs track "*tfevents*"
```
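The `"*tfevents*"` argument is an ordinary glob pattern. A quick sketch using Python's standard `fnmatch` module shows which files it would route through LFS (the filenames below are invented for illustration only):

```python
# The "*tfevents*" pattern given to `git lfs track` is a glob. This sketch uses
# Python's stdlib fnmatch to show which files it would match; the filenames
# are made up for illustration.
from fnmatch import fnmatch

files = [
    "events.out.tfevents.1620000000.tpu-vm.0",  # a tensorboard trace
    "flax_model.msgpack",                       # model weights
    "config.json",                              # model config
]
tracked = [f for f in files if fnmatch(f, "*tfevents*")]
print(tracked)  # ['events.out.tfevents.1620000000.tpu-vm.0']
```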
Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.
Next, let's add a symbolic link to the `run_mlm_flax.py` script.
```bash
export MODEL_DIR="./norwegian-roberta-base"
ln -s ~/transformers/examples/flax/language-modeling/run_mlm_flax.py run_mlm_flax.py
```
...@@ -98,7 +124,7 @@ Next we can run the example script to pretrain the model:
```bash
./run_mlm_flax.py \
    --output_dir="${MODEL_DIR}" \
    --model_type="roberta" \
    --config_name="${MODEL_DIR}" \
    --tokenizer_name="${MODEL_DIR}" \
...@@ -114,7 +140,8 @@ Next we can run the example script to pretrain the model:
    --pad_to_max_length \
    --num_train_epochs="18" \
    --adam_beta1="0.9" \
    --adam_beta2="0.98" \
    --push_to_hub
```
Training should converge at a loss and accuracy
...@@ -135,11 +162,37 @@ in Norwegian on a single TPUv3-8 pod.

The example script uses the 🤗 Datasets library. You can easily customize it if your datasets need extra processing.

Let's start by creating a model repository to save the trained model and logs.
Here we call the model `"norwegian-gpt2"`, but you can change the model name as you like.
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
you are logged in) or via the command line:
```
huggingface-cli repo create norwegian-gpt2
```
Next we clone the model repository to add the tokenizer and model files.
```
git clone https://huggingface.co/<your-username>/norwegian-gpt2
```
To ensure that all tensorboard traces are uploaded correctly, we need to
track them with Git LFS. You can run the following command inside your model repo to do so.
```
cd norwegian-gpt2
git lfs track "*tfevents*"
```
Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.
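The push itself happens through `model.save_pretrained(..., push_to_hub=True, commit_message=...)` at the end of each epoch, as the script changes further down show. A minimal pure-Python stand-in (not the real transformers API) illustrates the per-epoch commit-message pattern:

```python
# Minimal stand-in for the per-epoch save-and-push loop used by the example
# scripts. `DummyModel.save_pretrained` only mimics the call shape of the real
# transformers method; it does not write files or talk to the Hub.
class DummyModel:
    def save_pretrained(self, output_dir, params=None, push_to_hub=False, commit_message=None):
        # The real method serializes `params` into `output_dir` and, when
        # push_to_hub=True, commits and pushes the repo with `commit_message`.
        return commit_message if push_to_hub else None

model = DummyModel()
messages = [
    model.save_pretrained(
        "./norwegian-gpt2",
        params={},
        push_to_hub=True,
        commit_message=f"Saving weights and logs of epoch {epoch + 1}",
    )
    for epoch in range(3)
]
print(messages[-1])  # Saving weights and logs of epoch 3
```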
Next, let's add a symbolic link to the `run_clm_flax.py` script.
```bash
export MODEL_DIR="./norwegian-gpt2"
ln -s ~/transformers/examples/flax/language-modeling/run_clm_flax.py run_clm_flax.py
```
...@@ -166,7 +219,7 @@ Next we can run the example script to pretrain the model:
```bash
./run_clm_flax.py \
    --output_dir="${MODEL_DIR}" \
    --model_type="gpt2" \
    --config_name="${MODEL_DIR}" \
    --tokenizer_name="${MODEL_DIR}" \
...@@ -180,6 +233,7 @@ Next we can run the example script to pretrain the model:
    --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
    --overwrite_output_dir \
    --num_train_epochs="20" \
    --push_to_hub
```
Training should converge at a loss and perplexity
...@@ -197,14 +251,9 @@ For reproducibility, we state the training commands used for PyTorch/XLA and PyT
| Task | [TPU v3-8 (Flax)](https://tensorboard.dev/experiment/GdYmdak2TWeVz0DDRYOrrg/) | [TPU v3-8 (Pytorch/XLA)](https://tensorboard.dev/experiment/7Jq1kcQQRAmy12KOdXek7A/) | [8 GPU (PyTorch)](https://tensorboard.dev/experiment/PJneV8FQRxa2unPw1QnVHA) |
|-------|-----------|------------|------------|
| MLM | 15h32m | 23h46m | 44h14m |
*All experiments were run on Google Cloud Platform.
GPU experiments were run without further optimizations besides JAX
transformations, and in full precision (fp32). "TPU v3-8"
is 8 TPU cores on 4 chips (each chip has 2 cores), while "8 GPU" is 8 GPU chips.
...@@ -281,7 +330,7 @@ mkdir -p ${MODEL_DIR}
```bash
python3 -m torch.distributed.launch --nproc_per_node ${NUM_GPUS} run_mlm.py \
    --output_dir="${MODEL_DIR}" \
    --model_type="roberta" \
    --config_name="${MODEL_DIR}" \
    --tokenizer_name="${MODEL_DIR}" \
...
...@@ -451,7 +451,7 @@ def main():

    # Enable tensorboard only on the master node
    if has_tensorboard and jax.process_index() == 0:
        summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)

...@@ -604,10 +604,15 @@ def main():

        cur_step = epoch * (len(train_dataset) // train_batch_size)
        write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)

        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            params = jax.device_get(unreplicate(state.params))
            model.save_pretrained(
                training_args.output_dir,
                params=params,
                push_to_hub=training_args.push_to_hub,
                commit_message=f"Saving weights and logs of epoch {epoch+1}",
            )


if __name__ == "__main__":
...
...@@ -269,7 +269,7 @@ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndar

    return batch_idx


def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
    summary_writer.scalar("train_time", train_time, step)

    train_metrics = get_metrics(train_metrics)
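This hunk threads `summary_writer` through as an explicit parameter instead of relying on it being in scope. A simplified, self-contained sketch shows the fixed call shape; the stub writer and the flattened metrics dicts are assumptions for illustration (the real helper aggregates lists of per-device metrics with `get_metrics` first):

```python
# Simplified sketch of the fixed helper: the writer is an explicit parameter.
# StubWriter stands in for flax.metrics.tensorboard.SummaryWriter, and the
# metrics are plain dicts here rather than per-device metric lists.
class StubWriter:
    def __init__(self):
        self.logged = []

    def scalar(self, tag, value, step):
        self.logged.append((tag, value, step))


def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
    summary_writer.scalar("train_time", train_time, step)
    for tag, value in train_metrics.items():
        summary_writer.scalar(f"train_{tag}", value, step)
    for tag, value in eval_metrics.items():
        summary_writer.scalar(f"eval_{tag}", value, step)


writer = StubWriter()
write_metric(writer, {"loss": 1.9}, {"loss": 2.1}, 42.0, 100)
print(writer.logged)
# [('train_time', 42.0, 100), ('train_loss', 1.9, 100), ('eval_loss', 2.1, 100)]
```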
...@@ -472,7 +472,7 @@ if __name__ == "__main__":

    # Enable tensorboard only on the master node
    if has_tensorboard and jax.process_index() == 0:
        summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))

    # Data collator
    # This one will take care of randomly masking the tokens.

...@@ -642,9 +642,14 @@ if __name__ == "__main__":

        # Save metrics
        if has_tensorboard and jax.process_index() == 0:
            cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)

        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(
                training_args.output_dir,
                params=params,
                push_to_hub=training_args.push_to_hub,
                commit_message=f"Saving weights and logs of epoch {epoch+1}",
            )
...@@ -542,7 +542,7 @@ def main():

        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(

...@@ -787,10 +787,15 @@ def main():

            desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
            logger.info(desc)

        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(
                training_args.output_dir,
                params=params,
                push_to_hub=training_args.push_to_hub,
                commit_message=f"Saving weights and logs of epoch {epoch+1}",
            )


if __name__ == "__main__":
...
...@@ -23,31 +23,68 @@ Based on the script [`run_flax_glue.py`](https://github.com/huggingface/transfor

Fine-tuning the library models for sequence classification on the GLUE benchmark: [General Language Understanding
Evaluation](https://gluebenchmark.com/). This script can fine-tune any of the models on the [hub](https://huggingface.co/models).

To begin with, it is recommended to create a model repository to save the trained model and logs.
Here we call the model `"bert-glue-mrpc-test"`, but you can change the model name as you like.
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
you are logged in) or via the command line:
```
huggingface-cli repo create bert-glue-mrpc-test
```
Next we clone the model repository to add the tokenizer and model files.
```
git clone https://huggingface.co/<your-username>/bert-glue-mrpc-test
```
To ensure that all tensorboard traces are uploaded correctly, we need to
track them with Git LFS. You can run the following command inside your model repo to do so.
```
cd bert-glue-mrpc-test
git lfs track "*tfevents*"
```
Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.
Next, let's add a symbolic link to the `run_flax_glue.py` script.
```bash
export TASK_NAME=mrpc
export MODEL_DIR="./bert-glue-mrpc-test"
ln -s ~/transformers/examples/flax/text-classification/run_flax_glue.py run_flax_glue.py
```
GLUE is made up of a total of 9 different tasks. Here is how to run the script on one of them:
```bash
python run_flax_glue.py \
    --model_name_or_path bert-base-cased \
    --task_name ${TASK_NAME} \
    --max_length 128 \
    --learning_rate 2e-5 \
    --num_train_epochs 3 \
    --per_device_train_batch_size 4 \
    --output_dir ${MODEL_DIR} \
    --push_to_hub
```
where task name can be one of cola, mnli, mnli-mm, mrpc, qnli, qqp, rte, sst2, stsb, wnli.

Using the command above, the script will train for 3 epochs and run eval after each epoch.
Metrics and hyperparameters are stored in Tensorflow event files in `--output_dir`.
You can see the results by running `tensorboard` in that directory:
```bash
$ tensorboard --logdir .
```
or directly on the hub under *Training metrics*.
### Accuracy Evaluation

We train five replicas and report mean accuracy and stdev on the dev set below.
...@@ -95,14 +132,8 @@ overall training time below. For comparison we ran Pytorch's [run_glue.py](https

| WNLI | 1m 11s | 48s | 39s | 36s |
|-------|
| **TOTAL** | 1h 03m | 1h 28m | 5h 16m | 6h 37m |
*All experiments were run on Google Cloud Platform.
GPU experiments were run without further optimizations besides JAX
transformations, and in full precision (fp32). "TPU v3-8"
is 8 TPU cores on 4 chips (each chip has 2 cores), while "8 GPU" is 8 GPU chips.
...@@ -123,6 +123,11 @@ def parse_args():

    )
    parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
    parser.add_argument("--seed", type=int, default=3, help="A seed for reproducible training.")
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="If passed, model checkpoints and tensorboard logs will be pushed to the hub",
    )
    args = parser.parse_args()

    # Sanity checks
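The new flag can be exercised in isolation; this sketch reproduces just the `--push_to_hub` argument to show its default-off, switch-on behaviour:

```python
# Reproduces only the new flag: `store_true` means the attribute defaults to
# False and flips to True when --push_to_hub is passed.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--push_to_hub",
    action="store_true",
    help="If passed, model checkpoints and tensorboard logs will be pushed to the hub",
)

print(parser.parse_args([]).push_to_hub)                 # False
print(parser.parse_args(["--push_to_hub"]).push_to_hub)  # True
```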
...@@ -491,10 +496,15 @@ def main():

        cur_step = epoch * (len(train_dataset) // train_batch_size)
        write_metric(train_metrics, eval_metric, train_time, cur_step)

        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(
                args.output_dir,
                params=params,
                push_to_hub=args.push_to_hub,
                commit_message=f"Saving weights and logs of epoch {epoch}",
            )


if __name__ == "__main__":
...