• FlaxBart (#11537) · 4a51b1dd
    Daniel Stancl authored
    
    
    * Start working on FlaxBart
    
    * Create modeling_flax_bart.py
    
    * Write FlaxBartAttention
    
    * Add FlaxBartEncoderLayer
    
    * Add FlaxBartDecoderLayer and some typing
    
    * Add helper function for FlaxBart
    
    * shift_tokens_right
    
    * _make_causal_mask
    
    * _expand_mask
    
    * Add PositionalEmbedding and fix init_std naming
    
    * Add FlaxBartPretrainedModel
    
    * Add FlaxBartEncoder
    
    * Add FlaxBartEncoder
    
    * Add FlaxBartEncoder among modules to be imported
    
    * YET WE CANNOT INITIALIZE THAT!! :(
    
    * Make BartEncoder working
    
    Make BartEncoder an instance of nn.Module for now
    
    * Add FlaxBartDecoder
    
    * Add FlaxBartModel
    
    * TODO to make model run -> Prepare model inputs
    
    * Resolve padding
    
    * Add FlaxBartModel
    
    * Add FlaxBartModel into importable modules
    
    * Remove FlaxBartEncoder and FlaxBartDecoder from importable modules
    
    * make style; not properly working
    
    * make style; make quality does not pass due to an import I left in
    
    * Remove TODO for padding_idx in nn.Embed so far
    
    * Add FlaxBartForConditionalGeneration
    
    * Incorporate Flax model output classes, i.e. return_dict
    
    * Add other models and incorporate use_cache arg
    
    * Add FlaxBartForSequenceClassification and FlaxBartForQuestionAnswering
    
    * Incorporate use_cache arg from PyTorch implementation
    
    * Add all necessary Flax output utils
    
    * Add FlaxBartForCausalLM; not working yet
    
    * Add minor improvements; still lacks some functionality
    
    * Update docs, src and tests
    
    * Add support of FlaxBart to docs/source
    
    * Fix some bugs in FlaxBart source code
    
    * Add some necessary tests for FlaxBart models - jit_compilation not passing
    
    * Fix tests and add test_head_masking
    
    * Fix tests for @jax.jit computation
    
    * Add test_head_masking
    
    * Migrate FlaxBart tests from jax.numpy to numpy
    
    * Remove FlaxBartForCausalLM
    
    * Clean repo
    
    * fix bart model weight structure
    
    * Fix FlaxBartForSequenceClassification
    
    Slicing cannot be used under jit, so the way the sentence
    representation is selected from hidden_states must be changed.
    
    * Allow FlaxBartForSequenceClassification for testing pt_flax equivalence
    
    * Allow testing for FlaxBartForQA for pt_flax equivalence
    
    * Add a comment to FlaxBartForSequenceClassification + change noise from 1e-3 to 1e-6
    
    * remove past_key_values
    
    * remove inputs_embeds and make input_ids required
    
    * add position ids
    
    * re-write attention layer
    
    * fix dataclass
    
    * fix pos embeds and attention output
    
    * fix pos embeds
    
    * expose encode method
    
    * expose decode method
    
    * move docstring to top
    
    * add cache for causal attn layer
    
    * remove head masking for now
    
    * s2s greedy search first pass
    
    * boom boom
    
    * fix typos
    
    * fix greedy generate for bart
    
    * use encoder, decoder layers instead of num_hidden_layers
    
    * handle encoder_outputs
    
    * cleanup
    
    * simplify decoding
    
    * more clean-up
    
    * typos
    
    * Change header + add {decoder_,}position_ids into 2 models
    
    * add BartConfig
    
    * fix existing tests
    
    * add encode, decode methods
    
    * Fix shift_tokens_right for JIT compilation + clarify one condition
    
    * fix decode
    
    * encoder => encode
    
    * simplify generate
    
    * add tests for encode and decode
    
    * style
    
    * add tests for cache
    
    * fix equivalence tests
    
    * sample generate now works with seq2seq
    
    * generation tests
    
    * initialize dense layers
    
    * docstring and cleanup
    
    * quality
    
    * remove get/set input_embeddings
    
    * address Patrick's suggestions
    
    * decode for every model, remove encoder_outputs from call
    
    * update tests accordingly
    
    * decode returns only decoder outputs and logits
    
    * fix arguments
    
    * doc encode, decode methods
    
    * correct base_model_prefix
    
    * fix test for seq classif model
    
    * fix docs
    Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
    Co-authored-by: Suraj Patil <surajp815@gmail.com>
test_modeling_flax_bart.py 16.7 KB