Flax Masked Language Modeling training example (#8728)
* Remove "Model" suffix from Flax models to look more :hugs: Signed-off-by:Morgan Funtowicz <morgan@huggingface.co> * Initial working (forward + backward) for Flax MLM training example. Signed-off-by:
Morgan Funtowicz <morgan@huggingface.co> * Simply code Signed-off-by:
Morgan Funtowicz <morgan@huggingface.co> * Addressing comments, using module and moving to LM task. Signed-off-by:
Morgan Funtowicz <morgan@huggingface.co> * Restore parameter name "module" wrongly renamed model. Signed-off-by:
Morgan Funtowicz <morgan@huggingface.co> * Restore correct output ordering... Signed-off-by:
Morgan Funtowicz <morgan@huggingface.co> * Actually commit the example
😅 Signed-off-by:Morgan Funtowicz <morgan@huggingface.co> * Add FlaxBertModelForMaskedLM after rebasing. Signed-off-by:
Morgan Funtowicz <funtowiczmo@gmail.com> * Make it possible to initialize the training from scratch Signed-off-by:
Morgan Funtowicz <funtowiczmo@gmail.com> * Reuse flax linen example of cross entropy loss Signed-off-by:
Morgan Funtowicz <funtowiczmo@gmail.com> * Added specific data collator for flax Signed-off-by:
Morgan Funtowicz <funtowiczmo@gmail.com> * Remove todo for data collator Signed-off-by:
Morgan Funtowicz <funtowiczmo@gmail.com> * Added evaluation step Signed-off-by:
Morgan Funtowicz <funtowiczmo@gmail.com> * Added ability to provide dtype to support bfloat16 on TPU Signed-off-by:
Morgan Funtowicz <funtowiczmo@gmail.com> * Enable flax tensorboard output Signed-off-by:
Morgan Funtowicz <funtowiczmo@gmail.com> * Enable jax.pmap support. Signed-off-by:
Morgan Funtowicz <funtowiczmo@gmail.com> * Ensure batches are correctly sized to be dispatched with jax.pmap Signed-off-by:
Morgan Funtowicz <funtowiczmo@gmail.com> * Enable bfloat16 with --fp16 cmdline args Signed-off-by:
Morgan Funtowicz <funtowiczmo@gmail.com> * Correctly export metrics to tensorboard Signed-off-by:
Morgan Funtowicz <funtowiczmo@gmail.com> * Added dropout and ability to use it. Signed-off-by:
Morgan Funtowicz <funtowiczmo@gmail.com> * Effectively enable & disable during training and evaluation steps. Signed-off-by:
Morgan Funtowicz <funtowiczmo@gmail.com> * Oops. Signed-off-by:
Morgan Funtowicz <funtowiczmo@gmail.com> * Enable specifying kernel initializer scale Signed-off-by:
Morgan Funtowicz <funtowiczmo@gmail.com> * Style. Signed-off-by:
Morgan Funtowicz <funtowiczmo@gmail.com> * Added warmup step to the learning rate scheduler. Signed-off-by:
Morgan Funtowicz <funtowiczmo@gmail.com> * Fix typo. Signed-off-by:
Morgan Funtowicz <funtowiczmo@gmail.com> * Print training loss Signed-off-by:
Morgan Funtowicz <funtowiczmo@gmail.com> * Make style Signed-off-by:
Morgan Funtowicz <funtowiczmo@gmail.com> * fix linter issue (flake8) Signed-off-by:
Morgan Funtowicz <funtowiczmo@gmail.com> * Fix model matching Signed-off-by:
Morgan Funtowicz <funtowiczmo@gmail.com> * Fix dummies Signed-off-by:
Morgan Funtowicz <funtowiczmo@gmail.com> * Fix non default dtype on Flax models Signed-off-by:
Morgan Funtowicz <funtowiczmo@gmail.com> * Use the same create_position_ids_from_input_ids for FlaxRoberta Signed-off-by:
Morgan Funtowicz <funtowiczmo@gmail.com> * Make Roberta attention as Bert Signed-off-by:
Morgan Funtowicz <funtowiczmo@gmail.com> * fix copy Signed-off-by:
Morgan Funtowicz <funtowiczmo@gmail.com> * Wording. Co-authored-by:
Marc van Zee <marcvanzee@gmail.com> Co-authored-by:
Marc van Zee <marcvanzee@gmail.com>
Showing
Please register or sign in to comment