Unverified Commit bd987165 authored by Patrick von Platen, committed by GitHub

[Flax] Align GLUE training script with mlm training script (#11778)



* speed up flax glue

* remove unnecessary line

* remove folder

* remove run in loop
Co-authored-by: Patrick von Platen <patrick@huggingface.co>
parent 22394387
@@ -59,20 +59,19 @@ On the tasks other than MRPC and WNLI we train for 3 epochs because this is
 but looking at the training curves of some of them (e.g., SST-2, STS-B), it appears the models
 are undertrained and we could get better results when training longer.

-In the Tensorboard results linked below, the random seed of each model is equal to the ID of the run. So in order to reproduce run 1, run the command above with `--seed=1`. The best run used random seed 2, which is the default in the script. The results of all runs are in [this Google Sheet](https://docs.google.com/spreadsheets/d/1wtcjX_fJLjYs6kXkoiej2qGjrl9ByfNhPulPAz71Ky4/edit?usp=sharing).
+In the Tensorboard results linked below, the random seed of each model is equal to the ID of the run. So in order to reproduce run 1, run the command above with `--seed=1`. The best run used random seed 2, which is the default in the script. The results of all runs are in [this Google Sheet](https://docs.google.com/spreadsheets/d/1p3XzReMO75m_XdEJvPue-PIq_PN-96J2IJpJW1yS-10/edit?usp=sharing).
 | Task | Metric | Acc (best run) | Acc (avg/5runs) | Stdev | Metrics |
 |-------|------------------------------|----------------|-----------------|-----------|--------------------------------------------------------------------------|
-| CoLA  | Matthew's corr               | 59.29          | 56.25           | 2.18      | [tfhub.dev](https://tensorboard.dev/experiment/tNBiYyvsRv69ZlXRI7x0pQ/)  |
+| CoLA  | Matthew's corr               | 60.82          | 59.04           | 1.17      | [tfhub.dev](https://tensorboard.dev/experiment/U2ncNFP3RpWW6YnA9PYJBA/)  |
-| SST-2 | Accuracy                     | 91.97          | 91.79           | 0.42      | [tfhub.dev](https://tensorboard.dev/experiment/wQto9nBwQHOINUxjKAAblQ/)  |
+| SST-2 | Accuracy                     | 92.43          | 92.13           | 0.38      | [tfhub.dev](https://tensorboard.dev/experiment/vzxoOHZURcm0rO1I33x7uA/)  |
-| MRPC  | F1/Accuracy                  | 90.39/86.03    | 89.70/85.20     | 0.68/0.91 | [tfhub.dev](https://tensorboard.dev/experiment/Q40mkOtDSYymFRfo4jKsgQ/)  |
+| MRPC  | F1/Accuracy                  | 89.90/88.98    | 88.98/85.30     | 0.73/2.33 | [tfhub.dev](https://tensorboard.dev/experiment/EWPBIbfYSDGHjiYxrw2a2Q/)  |
-| STS-B | Pearson/Spearman corr.       | 89.19/88.91    | 89.40/89.09     | 0.18/0.14 | [tfhub.dev](https://tensorboard.dev/experiment/a2bfeAy6SveV0X0FjwxMXQ/)  |
+| STS-B | Pearson/Spearman corr.       | 89.04/88.70    | 88.94/88.63     | 0.07/0.07 | [tfhub.dev](https://tensorboard.dev/experiment/3aYHKL10TeiaZYwH1M8ogA/)  |
-| QQP   | Accuracy/F1                  | 91.02/87.90    | 90.96/87.75     | 0.08/0.14 | [tfhub.dev](https://tensorboard.dev/experiment/kL2vGgoQQeyTVGetehbCpg/)  |
+| QQP   | Accuracy/F1                  | 90.82/87.54    | 90.75/87.53     | 0.06/0.02 | [tfhub.dev](https://tensorboard.dev/experiment/VfVDLS4AQnqr4NMbng6yUw/)  |
-| MNLI  | Matched acc.                 | 83.82          | 83.65           | 0.28      | [tfhub.dev](https://tensorboard.dev/experiment/nck6178dTpmTOPm7862urA/)  |
+| MNLI  | Matched acc.                 | 84.10          | 83.84           | 0.16      | [tfhub.dev](https://tensorboard.dev/experiment/Sz9UdhoORaaSjzuOHRB4Jw/)  |
-| QNLI  | Accuracy                     | 90.81          | 90.88           | 0.18      | [tfhub.dev](https://tensorboard.dev/experiment/44slZTLKQtqGhWs1Rhedcg/)  |
+| QNLI  | Accuracy                     | 91.07          | 90.83           | 0.19      | [tfhub.dev](https://tensorboard.dev/experiment/zk6udb5MQAyAQ4eczrFBaQ/)  |
-| RTE   | Accuracy                     | 69.31          | 66.79           | 1.88      | [tfhub.dev](https://tensorboard.dev/experiment/g0yvpEXKSAytDMvP8TP8Og/)  |
+| RTE   | Accuracy                     | 66.06          | 64.76           | 1.04      | [tfhub.dev](https://tensorboard.dev/experiment/BwxaUoAEQ5aa3oQilEjADw/)  |
-| WNLI  | Accuracy                     | 56.34          | 36.62           | 12.48     | [tfhub.dev](https://tensorboard.dev/experiment/7DfXdlDnTWWKBEx4pXForA/)  |
+| WNLI  | Accuracy                     | 46.48          | 37.01           | 6.83      | [tfhub.dev](https://tensorboard.dev/experiment/b2Y8ouwMTRC8iBWzRzVYTA/)  |
 Some of these results are significantly different from the ones reported on the test set of GLUE benchmark on the
 website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the website.
@@ -85,18 +84,18 @@ overall training time below. For comparison we ran Pytorch's [run_glue.py](https
 | Task | TPU v3-8 | 8 GPU | [1 GPU](https://tensorboard.dev/experiment/mkPS4Zh8TnGe1HB6Yzwj4Q) | 1 GPU (Pytorch) |
 |-------|-----------|------------|------------|-----------------|
-| CoLA  | 1m 46s | 1m 26s | 3m 9s      | 4m 6s     |
+| CoLA  | 1m 42s | 1m 26s | 3m 9s      | 4m 6s     |
-| SST-2 | 5m 30s | 6m 28s | 22m 33s    | 34m 37s   |
+| SST-2 | 5m 12s | 6m 28s | 22m 33s    | 34m 37s   |
-| MRPC  | 1m 32s | 1m 14s | 2m 20s     | 2m 56s    |
+| MRPC  | 1m 29s | 1m 14s | 2m 20s     | 2m 56s    |
-| STS-B | 1m 33s | 1m 12s | 2m 16s     | 2m 48s    |
+| STS-B | 1m 30s | 1m 12s | 2m 16s     | 2m 48s    |
-| QQP   | 24m 40s | 31m 48s | 1h 59m 41s | 2h 54m    |
+| QQP   | 22m 50s | 31m 48s | 1h 59m 41s | 2h 54m    |
-| MNLI  | 26m 30s | 33m 55s | 2h 9m 37s  | 3h 7m 6s  |
+| MNLI  | 25m 03s | 33m 55s | 2h 9m 37s  | 3h 7m 6s  |
-| QNLI  | 8m     | 9m 40s | 34m 40s    | 49m 8s    |
+| QNLI  | 7m 30s | 9m 40s | 34m 40s    | 49m 8s    |
-| RTE   | 1m 21s | 55s    | 1m 10s     | 1m 16s    |
+| RTE   | 1m 20s | 55s    | 1m 10s     | 1m 16s    |
-| WNLI  | 1m 12s | 48s    | 39s        | 36s       |
+| WNLI  | 1m 11s | 48s    | 39s        | 36s       |
 |-------|
-| **TOTAL** | 1h 13m | 1h 28m | 5h 16m | 6h 37m |
+| **TOTAL** | 1h 03m | 1h 28m | 5h 16m | 6h 37m |
-| **COST*** | $9.60  | $29.10 | $13.06 | $16.41 |
+| **COST*** | $8.56  | $29.10 | $13.06 | $16.41 |

 *All experiments are run on Google Cloud Platform. Prices are on-demand prices
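As a sanity check on the COST row, cost is simply runtime multiplied by the hardware's on-demand hourly rate. The sketch below assumes an illustrative $8.00/hour rate for a TPU v3-8; the actual rate behind the table is not stated in this diff, so treat the numbers as approximate.

```python
# Rough reconstruction of the COST row: runtime (minutes) x hourly rate.
# The $8.00/h TPU v3-8 on-demand rate is an assumption for illustration,
# not a figure from this commit; GCP pricing varies by region and date.
def run_cost(minutes: float, usd_per_hour: float) -> float:
    return minutes / 60.0 * usd_per_hour

print(f"TPU v3-8, 1h 03m: ${run_cost(63, 8.00):.2f}")  # ~$8.40 vs. $8.56 reported
```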
...
@@ -34,7 +34,7 @@ from flax import struct, traverse_util
 from flax.jax_utils import replicate, unreplicate
 from flax.metrics import tensorboard
 from flax.training import train_state
-from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
+from flax.training.common_utils import get_metrics, onehot, shard

 from transformers import AutoConfig, AutoTokenizer, FlaxAutoModelForSequenceClassification, PretrainedConfig
@@ -407,6 +407,7 @@ def main():
     num_epochs = int(args.num_train_epochs)
     rng = jax.random.PRNGKey(args.seed)
+    dropout_rngs = jax.random.split(rng, jax.local_device_count())

     train_batch_size = args.per_device_train_batch_size * jax.local_device_count()
     eval_batch_size = args.per_device_eval_batch_size * jax.local_device_count()
@@ -424,6 +425,7 @@ def main():
         state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey
     ) -> Tuple[train_state.TrainState, float]:
         """Trains model with an optimizer (both in `state`) on `batch`, returning a pair `(new_state, loss)`."""
+        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
         targets = batch.pop("labels")

         def loss_fn(params):
@@ -436,7 +438,7 @@ def main():
         grad = jax.lax.pmean(grad, "batch")
         new_state = state.apply_gradients(grads=grad)
         metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch")
-        return new_state, metrics
+        return new_state, metrics, new_dropout_rng

     p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))
@@ -467,9 +469,7 @@ def main():
         # train
         for batch in glue_train_data_collator(input_rng, train_dataset, train_batch_size):
-            rng, dropout_rng = jax.random.split(rng)
-            dropout_rngs = shard_prng_key(dropout_rng)
-            state, metrics = p_train_step(state, batch, dropout_rngs)
+            state, metrics, dropout_rngs = p_train_step(state, batch, dropout_rngs)
             train_metrics.append(metrics)

         train_time += time.time() - train_start
         logger.info(f" Done! Training metrics: {unreplicate(metrics)}")
...
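Taken together, the hunks above replace per-step host-side RNG handling (split the host key, `shard_prng_key` it, then pass it to `p_train_step`) with per-device dropout keys that are created once before the loop; the pmapped step then splits its key on-device and returns the fresh half for the next iteration. Below is a minimal, self-contained sketch of that pattern. `toy_train_step`, its loss, and the update rule are illustrative stand-ins, not code from the script being patched.

```python
import jax
import jax.numpy as jnp


def toy_train_step(params, batch, dropout_rng):
    # Split on-device: one key drives this step's dropout, the other is
    # returned and becomes the device's key for the next step.
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

    def loss_fn(p):
        # Bernoulli mask as a stand-in for dropout inside a real model.
        keep = jax.random.bernoulli(dropout_rng, 0.9, batch.shape)
        return jnp.mean((batch * keep - p) ** 2)

    loss, grad = jax.value_and_grad(loss_fn)(params)
    grad = jax.lax.pmean(grad, axis_name="batch")  # all-reduce across devices
    return params - 0.1 * grad, loss, new_dropout_rng


p_train_step = jax.pmap(toy_train_step, axis_name="batch")

n_dev = jax.local_device_count()
rng = jax.random.PRNGKey(0)
dropout_rngs = jax.random.split(rng, n_dev)  # one key per device, made once
params = jnp.zeros((n_dev,))                 # replicated toy "parameters"
for _ in range(3):
    batch = jnp.ones((n_dev, 8))             # pre-sharded dummy batch
    params, loss, dropout_rngs = p_train_step(params, batch, dropout_rngs)
```

Because the keys never leave the devices, the loop body does no host-side splitting and no host-to-device key transfer per step, which is presumably where much of the TPU speedup in the timing table comes from. The real script's `donate_argnums=(0,)` additionally lets XLA reuse the old state's buffers when building the new state.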