Thien Tran authored
* set encoder's PE as non-trainable
* freeze flax
* init sinusoids
* add test for non-trainable embed positions
* simplify TF encoder embed_pos
* revert tf
* clean up
* add sinusoidal init for jax
* make consistent sinusoidal function
* fix dtype
* add default dtype
* use numpy for sinusoids. fix jax
* add sinusoid init for TF
* fix
* use custom embedding
* use specialized init for each impl
* fix sinusoids init. add test for pytorch
* fix TF dtype
* simplify sinusoid init for flax and tf
* add tests for TF
* change default dtype to float32
* add sinusoid test for flax
* Update src/transformers/models/whisper/modeling_flax_whisper.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/whisper/modeling_tf_whisper.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* move sinusoidal init to _init_weights

---------

Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
1e3c9dda