"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "d266613635ad5a68609b0c9f6b79a7794a99e813"
Unverified commit 726e953d, authored by Marc van Zee, committed by GitHub

Improvements to Flax finetuning script (#11727)

* Add Cloud details to README

* Flax script and readme updates

* Some simplifications of Flax script
parent 86d5fb0b
@@ -59,20 +59,20 @@
 On tasks other than MRPC and WNLI we train for 3 epochs, as this is standard,
 but looking at the training curves of some of them (e.g., SST-2, STS-B), it appears the models
 are undertrained and we could get better results when training longer.
 
-In the Tensorboard results linked below, the random seed of each model is equal to the ID of the run. So in order to reproduce run 1, run the command above with `--seed=1`. The best run used random seed 2, which is the default in the script. The results of all runs are in [this Google Sheet](https://docs.google.com/spreadsheets/d/1zKL_xn32HwbxkFMxB3ftca-soTHAuBFgIhYhOhCnZ4E/edit?usp=sharing).
+In the Tensorboard results linked below, the random seed of each model is equal to the ID of the run. So in order to reproduce run 1, run the command above with `--seed=1`. The best run used random seed 2, which is the default in the script. The results of all runs are in [this Google Sheet](https://docs.google.com/spreadsheets/d/1wtcjX_fJLjYs6kXkoiej2qGjrl9ByfNhPulPAz71Ky4/edit?usp=sharing).
 | Task  | Metric                 | Acc (best run) | Acc (avg/5 runs) | Stdev     | Metrics |
 |-------|------------------------|----------------|------------------|-----------|---------|
-| CoLA  | Matthew's corr         | 59.57          | 58.04            | 1.81      | [tensorboard.dev](https://tensorboard.dev/experiment/f4OvQpWtRq6CvddpxGBd0A/) |
+| CoLA  | Matthew's corr         | 59.29          | 56.25            | 2.18      | [tensorboard.dev](https://tensorboard.dev/experiment/tNBiYyvsRv69ZlXRI7x0pQ/) |
-| SST-2 | Accuracy               | 92.43          | 91.79            | 0.59      | [tensorboard.dev](https://tensorboard.dev/experiment/BYFwa49MRTaLIn93DgAEtA/) |
+| SST-2 | Accuracy               | 91.97          | 91.79            | 0.42      | [tensorboard.dev](https://tensorboard.dev/experiment/wQto9nBwQHOINUxjKAAblQ/) |
-| MRPC  | F1/Accuracy            | 89.50/84.8     | 88.70/84.02      | 0.56/0.48 | [tensorboard.dev](https://tensorboard.dev/experiment/9ZWH5xwXRS6zEEUE4RaBhQ/) |
+| MRPC  | F1/Accuracy            | 90.39/86.03    | 89.70/85.20      | 0.68/0.91 | [tensorboard.dev](https://tensorboard.dev/experiment/Q40mkOtDSYymFRfo4jKsgQ/) |
-| STS-B | Pearson/Spearman corr. | 90.00/88.71    | 89.09/88.61      | 0.51/0.07 | [tensorboard.dev](https://tensorboard.dev/experiment/mUlI5B9QQ0WGEJip7p3Tng/) |
+| STS-B | Pearson/Spearman corr. | 89.19/88.91    | 89.40/89.09      | 0.18/0.14 | [tensorboard.dev](https://tensorboard.dev/experiment/a2bfeAy6SveV0X0FjwxMXQ/) |
-| QQP   | Accuracy/F1            | 90.88/87.64    | 90.75/87.53      | 0.11/0.13 | [tensorboard.dev](https://tensorboard.dev/experiment/pO6h75L3SvSXSWRcgljXKA/) |
+| QQP   | Accuracy/F1            | 91.02/87.90    | 90.96/87.75      | 0.08/0.14 | [tensorboard.dev](https://tensorboard.dev/experiment/kL2vGgoQQeyTVGetehbCpg/) |
-| MNLI  | Matched acc.           | 84.06          | 83.88            | 0.16      | [tensorboard.dev](https://tensorboard.dev/experiment/LKwaOH18RMuo7nJkESrpKg/) |
+| MNLI  | Matched acc.           | 83.82          | 83.65            | 0.28      | [tensorboard.dev](https://tensorboard.dev/experiment/nck6178dTpmTOPm7862urA/) |
-| QNLI  | Accuracy               | 91.01          | 90.86            | 0.18      | [tensorboard.dev](https://tensorboard.dev/experiment/qesXxNcaQhmKxPmbw1sOoA/) |
+| QNLI  | Accuracy               | 90.81          | 90.88            | 0.18      | [tensorboard.dev](https://tensorboard.dev/experiment/44slZTLKQtqGhWs1Rhedcg/) |
-| RTE   | Accuracy               | 66.80          | 65.27            | 1.07      | [tensorboard.dev](https://tensorboard.dev/experiment/Z84xC0r6RjyzT4SLqiAbzQ/) |
+| RTE   | Accuracy               | 69.31          | 66.79            | 1.88      | [tensorboard.dev](https://tensorboard.dev/experiment/g0yvpEXKSAytDMvP8TP8Og/) |
-| WNLI  | Accuracy               | 39.44          | 32.96            | 5.85      | [tensorboard.dev](https://tensorboard.dev/experiment/gV73w9v0RIKrqVw32PZbAQ/) |
+| WNLI  | Accuracy               | 56.34          | 36.62            | 12.48     | [tensorboard.dev](https://tensorboard.dev/experiment/7DfXdlDnTWWKBEx4pXForA/) |
 
 Some of these results are significantly different from the ones reported on the test set of GLUE benchmark on the
 website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the website.
......
@@ -123,7 +123,7 @@ def parse_args():
         "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
-    parser.add_argument("--seed", type=int, default=2, help="A seed for reproducible training.")
+    parser.add_argument("--seed", type=int, default=5, help="A seed for reproducible training.")
     args = parser.parse_args()
 
     # Sanity checks
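
The updated call to `create_train_state` further down reads `args.weight_decay`, so the script presumably also exposes a matching command-line flag that is not shown in this excerpt. A hypothetical sketch of such an option (the flag name and default value are assumptions, not part of the diff):

```python
import argparse

# Hypothetical sketch, not the script's actual code: "--weight_decay" and its
# default are inferred from the args.weight_decay usage later in this diff.
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=5, help="A seed for reproducible training.")
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay applied to selected parameters.")
args = parser.parse_args([])  # empty argv so the sketch runs standalone
print(args.seed, args.weight_decay)
```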
@@ -148,6 +148,7 @@ def create_train_state(
     learning_rate_fn: Callable[[int], float],
     is_regression: bool,
     num_labels: int,
+    weight_decay: float,
 ) -> train_state.TrainState:
     """Create initial training state."""
@@ -166,8 +167,8 @@ def create_train_state(
         loss_fn: Callable = struct.field(pytree_node=False)
 
     # Creates a multi-optimizer consisting of two "Adam with weight decay" optimizers.
-    def adamw(weight_decay):
-        return optax.adamw(learning_rate=learning_rate_fn, b1=0.9, b2=0.999, eps=1e-6, weight_decay=weight_decay)
+    def adamw(decay):
+        return optax.adamw(learning_rate=learning_rate_fn, b1=0.9, b2=0.999, eps=1e-6, weight_decay=decay)
 
     def traverse(fn):
         def mask(data):
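
The `loss_fn` field in the context above uses the flax dataclass pattern of fields excluded from the parameter pytree. A minimal, self-contained sketch of that pattern; the toy `apply_fn` and loss are illustrative assumptions, not the script's code:

```python
from typing import Callable

import jax.numpy as jnp
import optax
from flax import struct
from flax.training import train_state


class TrainState(train_state.TrainState):
    # pytree_node=False keeps the callable out of the parameter pytree,
    # so JAX transformations and device replication ignore it.
    loss_fn: Callable = struct.field(pytree_node=False)


# Toy model and loss, only to show how the extra field is passed to create().
state = TrainState.create(
    apply_fn=lambda params, x: params["w"] * x,
    params={"w": jnp.array(1.0)},
    tx=optax.sgd(1e-2),
    loss_fn=lambda logits, labels: jnp.mean((logits - labels) ** 2),
)
loss = state.loss_fn(state.apply_fn(state.params, jnp.array(2.0)), jnp.array(3.0))
```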
@@ -183,7 +184,7 @@ def create_train_state(
     tx = optax.chain(
         optax.masked(adamw(0.0), mask=traverse(lambda path, _: decay_path(path))),
-        optax.masked(adamw(0.01), mask=traverse(lambda path, _: not decay_path(path))),
+        optax.masked(adamw(weight_decay), mask=traverse(lambda path, _: not decay_path(path))),
     )
 
     if is_regression:
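
For readers unfamiliar with `optax.masked`, here is a minimal, self-contained sketch of the two-group pattern used above: one AdamW instance runs with zero weight decay, the other with the configured value, and each only updates the parameters its mask selects. The toy parameter tree and the kernel/bias split are illustrative assumptions, not the script's `decay_path` logic:

```python
import jax
import jax.numpy as jnp
import optax

# Toy parameter tree standing in for the model parameters.
params = {"dense": {"kernel": jnp.ones((3, 3)), "bias": jnp.zeros((3,))}}

# Boolean masks with the same structure as params (illustrative split:
# decay kernels, do not decay biases).
decay_mask = {"dense": {"kernel": True, "bias": False}}
no_decay_mask = jax.tree_util.tree_map(lambda d: not d, decay_mask)


def adamw(decay):
    return optax.adamw(learning_rate=1e-3, b1=0.9, b2=0.999, eps=1e-6, weight_decay=decay)


# Chain two masked AdamW transforms; every parameter is updated by exactly one of them.
tx = optax.chain(
    optax.masked(adamw(0.0), mask=no_decay_mask),
    optax.masked(adamw(0.01), mask=decay_mask),
)

opt_state = tx.init(params)
grads = jax.tree_util.tree_map(jnp.ones_like, params)
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```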
@@ -414,7 +415,9 @@ def main():
         len(train_dataset), train_batch_size, args.num_train_epochs, args.num_warmup_steps, args.learning_rate
     )
-    state = create_train_state(model, learning_rate_fn, is_regression, num_labels=num_labels)
+    state = create_train_state(
+        model, learning_rate_fn, is_regression, num_labels=num_labels, weight_decay=args.weight_decay
+    )
 
     # define step functions
     def train_step(
@@ -426,10 +429,10 @@ def main():
         def loss_fn(params):
             logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
             loss = state.loss_fn(logits, targets)
-            return loss, logits
+            return loss
 
-        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
-        (loss, logits), grad = grad_fn(state.params)
+        grad_fn = jax.value_and_grad(loss_fn)
+        loss, grad = grad_fn(state.params)
         grad = jax.lax.pmean(grad, "batch")
         new_state = state.apply_gradients(grads=grad)
         metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch")
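
Since `loss_fn` now returns only the scalar loss, the auxiliary-output machinery is gone: `jax.value_and_grad` is called without `has_aux` and returns `(loss, grads)` directly. A minimal, self-contained illustration with a toy loss (not the script's model):

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    pred = params["w"] * x + params["b"]
    return jnp.mean((pred - y) ** 2)  # scalar loss only, no auxiliary outputs


params = {"w": jnp.array(1.0), "b": jnp.array(0.0)}
x, y = jnp.array([1.0, 2.0]), jnp.array([2.0, 4.0])

# Without has_aux, value_and_grad returns the loss value and the gradients directly.
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)

# The previous version returned (loss, logits) and therefore needed
# jax.value_and_grad(loss_fn, has_aux=True), unpacking ((loss, logits), grads).
```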
@@ -460,10 +463,11 @@ def main():
         train_start = time.time()
         train_metrics = []
-        rng, input_rng, dropout_rng = jax.random.split(rng, 3)
+        rng, input_rng = jax.random.split(rng)
 
         # train
         for batch in glue_train_data_collator(input_rng, train_dataset, train_batch_size):
+            rng, dropout_rng = jax.random.split(rng)
             dropout_rngs = shard_prng_key(dropout_rng)
             state, metrics = p_train_step(state, batch, dropout_rngs)
             train_metrics.append(metrics)
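
The dropout key is now split from the main key inside the batch loop, so every step gets a fresh key, and `shard_prng_key` hands each local device its own sub-key. A minimal, self-contained sketch of that pattern (the loop bounds are illustrative):

```python
import jax
from flax.training.common_utils import shard_prng_key

rng = jax.random.PRNGKey(0)
for step in range(3):  # stands in for the per-batch training loop
    rng, dropout_rng = jax.random.split(rng)    # fresh dropout key every step
    dropout_rngs = shard_prng_key(dropout_rng)  # one sub-key per local device
    # dropout_rngs would then be passed to the pmapped train step.
    print(step, dropout_rngs.shape)
```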
@@ -471,7 +475,6 @@ def main():
         logger.info(f" Done! Training metrics: {unreplicate(metrics)}")
 
         logger.info(" Evaluating...")
-        rng, input_rng = jax.random.split(rng)
 
         # evaluate
         for batch in glue_eval_data_collator(eval_dataset, eval_batch_size):
@@ -484,20 +487,14 @@ def main():
         # make sure leftover batch is evaluated on one device
         if num_leftover_samples > 0 and jax.process_index() == 0:
-            # put weights on single device
-            state = unreplicate(state)
-
             # take leftover samples
             batch = eval_dataset[-num_leftover_samples:]
             batch = {k: jnp.array(v) for k, v in batch.items()}
 
             labels = batch.pop("labels")
-            predictions = eval_step(state, batch)
+            predictions = eval_step(unreplicate(state), batch)
             metric.add_batch(predictions=predictions, references=labels)
 
-            # make sure weights are replicated on each device
-            state = replicate(state)
-
         eval_metric = metric.compute()
         logger.info(f" Done! Eval metrics: {eval_metric}")
......