"src/vscode:/vscode.git/clone" did not exist on "646659dd9927b8624179765ec1ce884a7bdae578"
  1. 09 Sep, 2022 1 commit
  2. 12 Aug, 2022 1 commit
  3. 01 Aug, 2022 1 commit
  4. 01 Jul, 2022 1 commit
    • Sanchit Gandhi's avatar
      [Flax] Add remat (gradient checkpointing) (#17843) · 485bbe79
      Sanchit Gandhi authored
      * [Flax] Add remat (gradient checkpointing)
      
      * fix variable naming in test
      
      * flip: checkpoint using a method
      
      * fix naming
      
      * fix class naming
      
      * apply PVP's suggestions from code review
      
      * make fix-copies
      
      * fix big-bird, electra, roberta
      
      * cookie-cutter
      
      * fix flax big-bird
      
      * move test to common
      485bbe79
  5. 22 Jun, 2022 1 commit
  6. 21 Jun, 2022 2 commits
  7. 19 Apr, 2022 1 commit
    • Suraj Patil's avatar
      [Flax] improve large model init and loading (#16148) · d3bd9ac7
      Suraj Patil authored
      
      
      * begin do_init
      
      * add params_shape_tree
      
      * raise error if params are accessed when do_init is False
      
      * don't allow do_init=False when keys are missing
      
      * make shape tree a property
      
      * assign self._params at the end
      
      * add test for do_init
      
      * add do_init arg to all flax models
      
      * fix param setting
      
      * disbale do_init for composite models
      
      * update test
      
      * add do_init in FlaxBigBirdForMultipleChoice
      
      * better names and errors
      
      * improve test
      
      * style
      
      * add a warning when do_init=False
      
      * remove extra if
      
      * set params after _required_params
      
      * add test for from_pretrained
      
      * do_init => _do_init
      
      * chage warning to info
      
      * fix typo
      
      * add params in init_weights
      
      * add params to gpt neo init
      
      * add params to init_weights
      
      * update do_init test
      
      * Trigger CI
      
      * Apply suggestions from code review
      Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
      
      * update template
      
      * trigger CI
      
      * style
      
      * style
      
      * fix template
      Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
      d3bd9ac7
  8. 29 Mar, 2022 1 commit
  9. 18 Mar, 2022 1 commit
  10. 09 Feb, 2022 1 commit
  11. 20 Dec, 2021 1 commit
  12. 17 Dec, 2021 1 commit
    • Daniel Stancl's avatar
      Implement head_mask for Flax BERT and other models copied from BERT (#14620) · ff066119
      Daniel Stancl authored
      * Implement head_mask for Flax BERT and other models copied from BERT
      
      * Remove `from jax._src.nn.functions import sigmoid`
      
      Remove `from jax._src.nn.functions import sigmoid` unintentionally added by IDE
      
      * Remove no more valid copy statement
      
      * Apply patil-suraj's suggestions from code review
      
      * Apply suggestions from the code review
      
      * Update Flax template
      
      * Fix a typo
      
      * Also update template for CausalLM modules
      ff066119
  13. 11 Nov, 2021 2 commits
    • Suraj Patil's avatar
      fix loading flax bf16 weights in pt (#14369) · 3d607df8
      Suraj Patil authored
      
      
      * fix loading flax bf16 weights in pt
      
      * fix clip test
      
      * fix t5 test
      
      * add logging statement
      
      * Update src/transformers/modeling_flax_pytorch_utils.py
      Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
      
      * switch back to native any
      
      * fix check for bf16 weights
      Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
      3d607df8
    • Suraj Patil's avatar
      Fix Flax params dtype (#13098) · e92190c0
      Suraj Patil authored
      
      
      * fix inits
      
      * fix embed dtype
      
      * fix embed dtype
      
      * add test to check default dtype
      
      * quality
      
      * add type conversion methods for flax models
      
      * more robust casting
      
      * cast sinusoidal positions
      
      * update pegasus
      
      * update albert
      
      * update test
      
      * make sure dtype is passed to every module
      
      * style
      
      * fix electra dense
      
      * fix t5
      
      * quality
      
      * add more tests
      
      * better name
      
      * use the dtype for lm head computation
      
      * fix albert
      
      * style
      
      * fix albert embed dtype
      
      * more tests
      
      * fix vision enc-dec
      
      * cleanup
      
      * fix embed dtype pegasus
      
      * fix default param test
      
      * doc
      
      * update template
      
      * fix final_logits_bias dtype
      
      * Apply suggestions from code review
      Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
      
      * fix doc
      
      * fix doc
      
      * add detailed docstring for dtype parameter
      
      * remove un-necessary import
      Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
      e92190c0
  14. 02 Nov, 2021 1 commit
  15. 21 Oct, 2021 1 commit
  16. 12 Aug, 2021 1 commit
  17. 05 Aug, 2021 1 commit
  18. 04 Aug, 2021 1 commit
  19. 13 Jul, 2021 1 commit
  20. 23 Jun, 2021 1 commit
  21. 21 Jun, 2021 2 commits
  22. 14 Jun, 2021 3 commits
    • Vasudev Gupta's avatar
      Flax Big Bird (#11967) · d9c0d08f
      Vasudev Gupta authored
      
      
      * add flax bert
      
      * bert -> bigbird
      
      * original_full ported
      
      * add debugger
      
      * init block sparse
      
      * fix copies ; gelu_fast -> gelu_new
      
      * block sparse port
      
      * fix block sparse
      
      * block sparse working
      
      * all ckpts working
      
      * fix-copies
      
      * make quality
      
      * init tests
      
      * temporary fix for FlaxBigBirdForMultipleChoice
      
      * skip test_attention_outputs
      
      * fix
      
      * gelu_fast -> gelu_new ; fix multiple choice model
      
      * remove nsp
      
      * fix sequence classifier
      
      * fix
      
      * make quality
      
      * make fix-copies
      
      * finish
      
      * Delete debugger.ipynb
      
      * Update src/transformers/models/big_bird/modeling_flax_big_bird.py
      
      * make style
      
      * finish
      
      * bye bye jit flax tests
      Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
      d9c0d08f
    • Patrick von Platen's avatar
      [Flax] Fix flax pt equivalence tests (#12154) · 007be9e4
      Patrick von Platen authored
      * fix_torch_device_generate_test
      
      * remove @
      
      * upload
      007be9e4
    • Daniel Stancl's avatar
      FlaxBart (#11537) · 4a51b1dd
      Daniel Stancl authored
      
      
      * Start working on FlaxBart
      
      * Create modeling_flax_bart.py
      
      * Write FlaxBartAttention
      
      * Add FlaxBartEncoderLayer
      
      * Add FlaxBartDecoderLayer and some typing
      
      * Add helepr function for FlaxBart
      
      * shift_tokens_right
      
      * _make_causal_mask
      
      * _expand_mask
      
      * Add PositionalEmbedding and fix init_std naming
      
      * Add FlaxBartPretrainedModel
      
      * Add FlaxBartEncoder
      
      * Add FlaxBartEncoder
      
      * Add FlaxBartEncoder among modules to be imported
      
      * YET WE CANNOT INITIALIZE THAT!! :(
      
      * Make BartEncoder working
      
      Change BartEncoder to instance of nn.Module so far
      
      * Add FlaxBartDecoder
      
      * Add FlaxBartModel
      
      * TODO to make model run -> Prepapre model inputs
      
      * Resolve padding
      
      * Add FlaxBartModel
      
      * Add FlaxBartModel into importable modules
      
      * Remove FlaxBartEncoder and FlaxBartDecoder from importable modules
      
      * make style; not properly working
      
      * make style; make quality not pass due to some import I left
      
      * Remove TODO for padding_idx in nn.Embed so far
      
      * Add FlaxBartForConditionalGeneration
      
      * Incorporate Flax model output classes, i.e. return_dict
      
      * Add another models and incorporate use_cache arg
      
      * Add FlaxBartForSequenceClassification and FlaxBartForQuestionAnswering
      
      * Incorporate use_cache arg from PyTorch implementation
      
      * Add all necessary Flax output utils
      
      * Add FlaxBartForCausalLM; not working yet'
      
      * Add minor improvements; still lacks some functionality
      
      * Update docs, src and tests
      
      * Add support of FlaxBart to docs/source
      
      * Fix some bugs in FlaxBart souce code
      
      * Add some neccessary tests for FlaxBart models - jit_compilation not passing
      
      * Fix tests and add test_head_masking
      
      * Fix tests for @jax.jit computation
      
      * Add test_head_masking
      
      * Migrate FlaxBart tests from jax.numpy to numpy
      
      * Remove FlaxBartForCausalLM
      
      * Clean repo
      
      * fix bart model weight structure
      
      * Fix FlaxBartForSequenceClassification
      
      Slicing is not possible to use below jit, therefore, selecting sentence
      representation from hidden_states must be changed.
      
      * Allow FlaxBartForSequenceClassification for testing pt_flax equivalence
      
      * Allow testing for FlaxBartForQA for pt_flax equivalence
      
      * Add a comment to FlaxBartForSequenceClassification + change noise from 1e-3 to 1e-6
      
      * remove past_key_values
      
      * remove inputs_mebeds and make input_ids required
      
      * add position ids
      
      * re-write attention layer
      
      * fix dataclass
      
      * fix pos embeds and attention output
      
      * fix pos embeds
      
      * expose encode method
      
      * expose decode method
      
      * move docstring to top
      
      * add cache for causal attn layer
      
      * remove head masking for now
      
      * s2s greedy search first pass
      
      * boom boom
      
      * fix typos
      
      * fix greedy generate for bart
      
      * use encoder, decoder layers instead of num_hidden_layers
      
      * handle encoder_outputs
      
      * cleanup
      
      * simplify decoding
      
      * more clean-up
      
      * typos
      
      * Change header + add {decoder_,}position_ids into 2 models
      
      * add BartConfig
      
      * fix existing tests
      
      * add encode, decode methods
      
      * Fix shift_tokens_right for JIT compilation + clarify one condition
      
      * fix decode
      
      * encoder => encode
      
      * simplify generate
      
      * add tests for encode and decode
      
      * style
      
      * add tests for cache
      
      * fix equivalence tests
      
      * sample generate now works with seq2seq
      
      * generation tests
      
      * initialize dense layers
      
      * docstring and cleanup
      
      * quality
      
      * remove get/set input_embeddings
      
      * address Patricks suggestions
      
      * decode for every model, remove encoder_outputs from call
      
      * update tests accordingly
      
      * decode returns only decoder outputs and logits
      
      * fix arguments
      
      * doc encode, decode methods
      
      * correct base_model_prefix
      
      * fix test for seq classif model
      
      * fix docs
      Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
      Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
      4a51b1dd
  23. 01 Jun, 2021 1 commit
    • Suraj Patil's avatar
      Add FlaxCLIP (#11883) · ad25fd62
      Suraj Patil authored
      * add flax CLIP
      
      * default input_shape
      
      * add tests
      
      * fix test
      
      * fix name
      
      * fix docs
      
      * fix shapes
      
      * attend at least 1 token
      
      * flax conv to torch conv
      
      * return floats
      
      * fix equivalence tests
      
      * fix import
      
      * return attention_weights and update tests
      
      * fix dosctrings
      
      * address patricks comments
      
      * input_shape arg
      
      * add tests for get_image_features and get_text_features methods
      
      * fix tests
      ad25fd62
  24. 28 May, 2021 1 commit
  25. 26 May, 2021 1 commit
  26. 18 May, 2021 1 commit
    • Suraj Patil's avatar
      FlaxGPT2 (#11556) · ca33278f
      Suraj Patil authored
      
      
      * flax gpt2
      
      * combine masks
      
      * handle shared embeds
      
      * add causal LM sample
      
      * style
      
      * add tests
      
      * style
      
      * fix imports, docs, quality
      
      * don't use cache
      
      * add cache
      
      * add cache 1st version
      
      * make use cache work
      
      * start adding test for generation
      
      * finish generation loop compilation
      
      * rewrite test
      
      * finish
      
      * update
      
      * update
      
      * apply sylvains suggestions
      
      * update
      
      * refactor
      
      * fix typo
      Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
      ca33278f
  27. 04 May, 2021 1 commit
  28. 29 Apr, 2021 1 commit
  29. 23 Apr, 2021 1 commit
    • Patrick von Platen's avatar
      [Flax] Big FlaxBert Refactor (#11364) · 8c9b5fcb
      Patrick von Platen authored
      * improve flax
      
      * refactor
      
      * typos
      
      * Update src/transformers/modeling_flax_utils.py
      
      * Apply suggestions from code review
      
      * Update src/transformers/modeling_flax_utils.py
      
      * fix typo
      
      * improve error tolerance
      
      * typo
      
      * correct nasty saving bug
      
      * fix from pretrained
      
      * correct tree map
      
      * add note
      
      * correct weight tying
      8c9b5fcb
  30. 31 Mar, 2021 1 commit
  31. 30 Mar, 2021 1 commit
  32. 18 Mar, 2021 1 commit
    • Patrick von Platen's avatar
      [Flax] Adapt Flax models to new structure (#9484) · 0b98ca36
      Patrick von Platen authored
      
      
      * Create modeling_flax_eletra with code copied from modeling_flax_bert
      
      * Add ElectraForMaskedLM and ElectraForPretraining
      
      * Add modeling test for Flax electra and fix naming and arg in Flax Electra model
      
      * Add documentation
      
      * Fix code style
      
      * Create modeling_flax_eletra with code copied from modeling_flax_bert
      
      * Add ElectraForMaskedLM and ElectraForPretraining
      
      * Add modeling test for Flax electra and fix naming and arg in Flax Electra model
      
      * Add documentation
      
      * Fix code style
      
      * Fix code quality
      
      * Adjust tol in assert_almost_equal due to very small difference between model output, ranging 0.0010 - 0.0016
      
      * Remove redundant ElectraPooler
      
      * save intermediate
      
      * adapt
      
      * correct bert flax design
      
      * adapt roberta as well
      
      * finish roberta flax
      
      * finish
      
      * apply suggestions
      
      * apply suggestions
      Co-authored-by: default avatarChris Nguyen <anhtu2687@gmail.com>
      0b98ca36
  33. 16 Mar, 2021 1 commit
  34. 16 Dec, 2020 1 commit
    • Patrick von Platen's avatar
      [Flax] Align FlaxBertForMaskedLM with BertForMaskedLM, implement from_pretrained, init (#9054) · 640e6fe1
      Patrick von Platen authored
      
      
      * save intermediate
      
      * save intermediate
      
      * save intermediate
      
      * correct flax bert model file
      
      * new module / model naming
      
      * make style
      
      * almost finish BERT
      
      * finish roberta
      
      * make fix-copies
      
      * delete keys file
      
      * last refactor
      
      * fixes in run_mlm_flax.py
      
      * remove pooled from run_mlm_flax.py`
      
      * fix gelu | gelu_new
      
      * remove Module from inits
      
      * splits
      
      * dirty print
      
      * preventing warmup_steps == 0
      
      * smaller splits
      
      * make fix-copies
      
      * dirty print
      
      * dirty print
      
      * initial_evaluation argument
      
      * declaration order fix
      
      * proper model initialization/loading
      
      * proper initialization
      
      * run_mlm_flax improvements: improper model inputs bugfix + automatic dataset splitting + tokenizers parallelism warning + avoiding warmup_steps=0 bug
      
      * removed tokenizers warning hack, fixed model re-initialization
      
      * reverted training_args.py changes
      
      * fix flax from pretrained
      
      * improve test in flax
      
      * apply sylvains tips
      
      * update init
      
      * make 0.3.0 compatible
      
      * revert tevens changes
      
      * revert tevens changes 2
      
      * finalize revert
      
      * fix bug
      
      * add docs
      
      * add pretrained to init
      
      * Update src/transformers/modeling_flax_utils.py
      
      * fix copies
      
      * final improvements
      Co-authored-by: default avatarTevenLeScao <teven.lescao@gmail.com>
      640e6fe1
  35. 10 Dec, 2020 1 commit