Unverified Commit d1f5ca1a authored by Kamal Raj's avatar Kamal Raj Committed by GitHub
Browse files

[FLAX] glue training example refactor (#13815)

* refactor run_flax_glue.py

* updated readme

* rm unused import and args typo fix

* refactor

* make consistent arg name across task

* has_tensorboard check

* argparse -> argument dataclasses

* refactor according to review

* fix
parent db350394
......@@ -85,10 +85,10 @@ class ExamplesTests(TestCasePlus):
--per_device_train_batch_size=2
--per_device_eval_batch_size=1
--learning_rate=1e-4
--max_train_steps=10
--num_warmup_steps=2
--eval_steps=2
--warmup_steps=2
--seed=42
--max_length=128
--max_seq_length=128
""".split()
with patch.object(sys, "argv", testargs):
......
......@@ -33,15 +33,16 @@ export TASK_NAME=mrpc
python run_flax_glue.py \
--model_name_or_path bert-base-cased \
--task_name ${TASK_NAME} \
--max_length 128 \
--max_seq_length 128 \
--learning_rate 2e-5 \
--num_train_epochs 3 \
--per_device_train_batch_size 4 \
--eval_steps 100 \
--output_dir ./$TASK_NAME/ \
--push_to_hub
```
where task name can be one of cola, mnli, mnli-mm, mrpc, qnli, qqp, rte, sst2, stsb, wnli.
where task name can be one of cola, mnli, mnli_mismatched, mnli_matched, mrpc, qnli, qqp, rte, sst2, stsb, wnli.
Using the command above, the script will train for 3 epochs and run eval after each epoch.
Metrics and hyperparameters are stored in Tensorflow event files in `--output_dir`.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment