Unverified Commit d1f5ca1a authored by Kamal Raj's avatar Kamal Raj Committed by GitHub
Browse files

[FLAX] glue training example refactor (#13815)

* refactor run_flax_glue.py

* updated readme

* rm unused import and args typo fix

* refactor

* make consistent arg name across task

* has_tensorboard check

* argparse -> argument dataclasses

* refactor according to review

* fix
parent db350394
...@@ -85,10 +85,10 @@ class ExamplesTests(TestCasePlus): ...@@ -85,10 +85,10 @@ class ExamplesTests(TestCasePlus):
--per_device_train_batch_size=2 --per_device_train_batch_size=2
--per_device_eval_batch_size=1 --per_device_eval_batch_size=1
--learning_rate=1e-4 --learning_rate=1e-4
--max_train_steps=10 --eval_steps=2
--num_warmup_steps=2 --warmup_steps=2
--seed=42 --seed=42
--max_length=128 --max_seq_length=128
""".split() """.split()
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
......
...@@ -33,15 +33,16 @@ export TASK_NAME=mrpc ...@@ -33,15 +33,16 @@ export TASK_NAME=mrpc
python run_flax_glue.py \ python run_flax_glue.py \
--model_name_or_path bert-base-cased \ --model_name_or_path bert-base-cased \
--task_name ${TASK_NAME} \ --task_name ${TASK_NAME} \
--max_length 128 \ --max_seq_length 128 \
--learning_rate 2e-5 \ --learning_rate 2e-5 \
--num_train_epochs 3 \ --num_train_epochs 3 \
--per_device_train_batch_size 4 \ --per_device_train_batch_size 4 \
--eval_steps 100 \
--output_dir ./$TASK_NAME/ \ --output_dir ./$TASK_NAME/ \
--push_to_hub --push_to_hub
``` ```
where task name can be one of cola, mnli, mnli-mm, mrpc, qnli, qqp, rte, sst2, stsb, wnli. where task name can be one of cola, mnli, mnli_mismatched, mnli_matched, mrpc, qnli, qqp, rte, sst2, stsb, wnli.
Using the command above, the script will train for 3 epochs and run eval after each epoch. Using the command above, the script will train for 3 epochs and run eval after each epoch.
Metrics and hyperparameters are stored in Tensorflow event files in `--output_dir`. Metrics and hyperparameters are stored in Tensorflow event files in `--output_dir`.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment