Unverified Commit 726e953d authored by Marc van Zee's avatar Marc van Zee Committed by GitHub
Browse files

Improvements to Flax finetuning script (#11727)

* Add Cloud details to README

* Flax script and readme updates

* Some simplifications of Flax script
parent 86d5fb0b
......@@ -59,20 +59,20 @@ On the task other than MRPC and WNLI we train for 3 these epochs because this is
but looking at the training curves of some of them (e.g., SST-2, STS-b), it appears the models
are undertrained and we could get better results when training longer.
In the Tensorboard results linked below, the random seed of each model is equal to the ID of the run. So in order to reproduce run 1, run the command above with `--seed=1`. The best run used random seed 2, which is the default in the script. The results of all runs are in [this Google Sheet](https://docs.google.com/spreadsheets/d/1zKL_xn32HwbxkFMxB3ftca-soTHAuBFgIhYhOhCnZ4E/edit?usp=sharing).
In the Tensorboard results linked below, the random seed of each model is equal to the ID of the run. So in order to reproduce run 1, run the command above with `--seed=1`. The best run used random seed 2, which is the default in the script. The results of all runs are in [this Google Sheet](https://docs.google.com/spreadsheets/d/1wtcjX_fJLjYs6kXkoiej2qGjrl9ByfNhPulPAz71Ky4/edit?usp=sharing).
| Task | Metric | Acc (best run) | Acc (avg/5runs) | Stdev | Metrics |
|-------|------------------------------|----------------|-----------------|-----------|--------------------------------------------------------------------------|
| CoLA | Matthew's corr | 59.57 | 58.04 | 1.81 | [tfhub.dev](https://tensorboard.dev/experiment/f4OvQpWtRq6CvddpxGBd0A/) |
| SST-2 | Accuracy | 92.43 | 91.79 | 0.59 | [tfhub.dev](https://tensorboard.dev/experiment/BYFwa49MRTaLIn93DgAEtA/) |
| MRPC | F1/Accuracy | 89.50/84.8 | 88.70/84.02 | 0.56/0.48 | [tfhub.dev](https://tensorboard.dev/experiment/9ZWH5xwXRS6zEEUE4RaBhQ/) |
| STS-B | Pearson/Spearman corr. | 90.00/88.71 | 89.09/88.61 | 0.51/0.07 | [tfhub.dev](https://tensorboard.dev/experiment/mUlI5B9QQ0WGEJip7p3Tng/) |
| QQP | Accuracy/F1 | 90.88/87.64 | 90.75/87.53 | 0.11/0.13 | [tfhub.dev](https://tensorboard.dev/experiment/pO6h75L3SvSXSWRcgljXKA/) |
| MNLI | Matched acc. | 84.06 | 83.88 | 0.16 | [tfhub.dev](https://tensorboard.dev/experiment/LKwaOH18RMuo7nJkESrpKg/) |
| QNLI | Accuracy | 91.01 | 90.86 | 0.18 | [tfhub.dev](https://tensorboard.dev/experiment/qesXxNcaQhmKxPmbw1sOoA/) |
| RTE | Accuracy | 66.80 | 65.27 | 1.07 | [tfhub.dev](https://tensorboard.dev/experiment/Z84xC0r6RjyzT4SLqiAbzQ/) |
| WNLI | Accuracy | 39.44 | 32.96 | 5.85 | [tfhub.dev](https://tensorboard.dev/experiment/gV73w9v0RIKrqVw32PZbAQ/) |
| CoLA | Matthew's corr | 59.29 | 56.25 | 2.18 | [tfhub.dev](https://tensorboard.dev/experiment/tNBiYyvsRv69ZlXRI7x0pQ/) |
| SST-2 | Accuracy | 91.97 | 91.79 | 0.42 | [tfhub.dev](https://tensorboard.dev/experiment/wQto9nBwQHOINUxjKAAblQ/) |
| MRPC | F1/Accuracy | 90.39/86.03 | 89.70/85.20 | 0.68/0.91 | [tfhub.dev](https://tensorboard.dev/experiment/Q40mkOtDSYymFRfo4jKsgQ/) |
| STS-B | Pearson/Spearman corr. | 89.19/88.91 | 89.40/89.09 | 0.18/0.14 | [tfhub.dev](https://tensorboard.dev/experiment/a2bfeAy6SveV0X0FjwxMXQ/) |
| QQP | Accuracy/F1 | 91.02/87.90 | 90.96/87.75 | 0.08/0.14 | [tfhub.dev](https://tensorboard.dev/experiment/kL2vGgoQQeyTVGetehbCpg/) |
| MNLI | Matched acc. | 83.82 | 83.65 | 0.28 | [tfhub.dev](https://tensorboard.dev/experiment/nck6178dTpmTOPm7862urA/) |
| QNLI | Accuracy | 90.81 | 90.88 | 0.18 | [tfhub.dev](https://tensorboard.dev/experiment/44slZTLKQtqGhWs1Rhedcg/) |
| RTE | Accuracy | 69.31 | 66.79 | 1.88 | [tfhub.dev](https://tensorboard.dev/experiment/g0yvpEXKSAytDMvP8TP8Og/) |
| WNLI | Accuracy | 56.34 | 36.62 | 12.48 | [tfhub.dev](https://tensorboard.dev/experiment/7DfXdlDnTWWKBEx4pXForA/) |
Some of these results are significantly different from the ones reported on the test set of GLUE benchmark on the
website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the website.
......
......@@ -123,7 +123,7 @@ def parse_args():
"--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
parser.add_argument("--seed", type=int, default=2, help="A seed for reproducible training.")
parser.add_argument("--seed", type=int, default=5, help="A seed for reproducible training.")
args = parser.parse_args()
# Sanity checks
......@@ -148,6 +148,7 @@ def create_train_state(
learning_rate_fn: Callable[[int], float],
is_regression: bool,
num_labels: int,
weight_decay: float,
) -> train_state.TrainState:
"""Create initial training state."""
......@@ -166,8 +167,8 @@ def create_train_state(
loss_fn: Callable = struct.field(pytree_node=False)
# Creates a multi-optimizer consisting of two "Adam with weight decay" optimizers.
def adamw(weight_decay):
return optax.adamw(learning_rate=learning_rate_fn, b1=0.9, b2=0.999, eps=1e-6, weight_decay=weight_decay)
def adamw(decay):
return optax.adamw(learning_rate=learning_rate_fn, b1=0.9, b2=0.999, eps=1e-6, weight_decay=decay)
def traverse(fn):
def mask(data):
......@@ -183,7 +184,7 @@ def create_train_state(
tx = optax.chain(
optax.masked(adamw(0.0), mask=traverse(lambda path, _: decay_path(path))),
optax.masked(adamw(0.01), mask=traverse(lambda path, _: not decay_path(path))),
optax.masked(adamw(weight_decay), mask=traverse(lambda path, _: not decay_path(path))),
)
if is_regression:
......@@ -414,7 +415,9 @@ def main():
len(train_dataset), train_batch_size, args.num_train_epochs, args.num_warmup_steps, args.learning_rate
)
state = create_train_state(model, learning_rate_fn, is_regression, num_labels=num_labels)
state = create_train_state(
model, learning_rate_fn, is_regression, num_labels=num_labels, weight_decay=args.weight_decay
)
# define step functions
def train_step(
......@@ -426,10 +429,10 @@ def main():
def loss_fn(params):
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
loss = state.loss_fn(logits, targets)
return loss, logits
return loss
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grad = grad_fn(state.params)
grad_fn = jax.value_and_grad(loss_fn)
loss, grad = grad_fn(state.params)
grad = jax.lax.pmean(grad, "batch")
new_state = state.apply_gradients(grads=grad)
metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch")
......@@ -460,10 +463,11 @@ def main():
train_start = time.time()
train_metrics = []
rng, input_rng, dropout_rng = jax.random.split(rng, 3)
rng, input_rng = jax.random.split(rng)
# train
for batch in glue_train_data_collator(input_rng, train_dataset, train_batch_size):
rng, dropout_rng = jax.random.split(rng)
dropout_rngs = shard_prng_key(dropout_rng)
state, metrics = p_train_step(state, batch, dropout_rngs)
train_metrics.append(metrics)
......@@ -471,7 +475,6 @@ def main():
logger.info(f" Done! Training metrics: {unreplicate(metrics)}")
logger.info(" Evaluating...")
rng, input_rng = jax.random.split(rng)
# evaluate
for batch in glue_eval_data_collator(eval_dataset, eval_batch_size):
......@@ -484,20 +487,14 @@ def main():
# make sure leftover batch is evaluated on one device
if num_leftover_samples > 0 and jax.process_index() == 0:
# put weights on single device
state = unreplicate(state)
# take leftover samples
batch = eval_dataset[-num_leftover_samples:]
batch = {k: jnp.array(v) for k, v in batch.items()}
labels = batch.pop("labels")
predictions = eval_step(state, batch)
predictions = eval_step(unreplicate(state), batch)
metric.add_batch(predictions=predictions, references=labels)
# make sure weights are replicated on each device
state = replicate(state)
eval_metric = metric.compute()
logger.info(f" Done! Eval metrics: {eval_metric}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment