    Flax mistral (#26943) · f7076cd3
    Kian Sierra McGettigan authored
    * direct copy from llama work
    
    * mistral modules forward pass working
    
    * flax mistral forward pass with sliding window
    
    * added tests
    
    * added layer collection approach
    
    * Revert "added layer collection approach"
    
    This reverts commit 0e2905bf2236ec323163fc1a9f0c016b21aa8b8f.
    
    * Revert "Revert "added layer collection approach""
    
    This reverts commit fb17b6187ac5d16da7c461e1130514dc3d137a43.
    
    * fixed attention outputs
    
    * added mistral to init and auto
    
    * fixed import name
    
    * fixed layernorm weight dtype
    
    * freeze initialized weights
    
* make sure conversion considers bfloat16
    
    * added backend
    
    * added docstrings
    
    * added cache
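
    The decoding cache referenced here follows the usual Flax pattern: pre-allocated key/value buffers written in place at a running index during autoregressive generation. A minimal NumPy sketch of that pattern (not the PR's actual code, which uses Flax mutable "cache" variables):

    ```python
    import numpy as np

    def update_cache(cached_key, cached_value, key, value, cache_index):
        # Write the new key/value states into the pre-allocated buffers at
        # cache_index along the sequence axis, then advance the index.
        seq = key.shape[1]
        cached_key[:, cache_index:cache_index + seq] = key
        cached_value[:, cache_index:cache_index + seq] = value
        return cached_key, cached_value, cache_index + seq

    batch, max_len, dim = 1, 4, 2
    ck = np.zeros((batch, max_len, dim))
    cv = np.zeros((batch, max_len, dim))
    idx = 0
    # One decoding step appends a single position to the cache.
    step = np.ones((batch, 1, dim))
    ck, cv, idx = update_cache(ck, cv, step, step, idx)
    ```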
    
    * fixed sliding window causal mask
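
    The sliding-window causal mask can be sketched as follows: query position i may attend only to key positions j with j <= i (causal) and i - j < window (sliding window). A minimal NumPy illustration, not the PR's actual code (`jax.numpy` is a drop-in replacement here):

    ```python
    import numpy as np

    def sliding_window_causal_mask(seq_len, window):
        # True where query position i is allowed to attend to key position j.
        i = np.arange(seq_len)[:, None]  # query positions
        j = np.arange(seq_len)[None, :]  # key positions
        causal = j <= i                  # no attending to the future
        in_window = (i - j) < window     # no attending beyond the window
        return causal & in_window

    # With window=3, position 4 attends to positions 2, 3, 4 only.
    mask = sliding_window_causal_mask(5, 3)
    ```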
    
    * passes cache tests
    
    * passed all tests
    
    * applied make style
    
    * removed commented out code
    
    * applied fix-copies ignored other model changes
    
    * applied make fix-copies
    
    * removed unused functions
    
    * passed generation integration test
    
    * slow tests pass
    
    * fixed slow tests
    
    * changed default dtype from jax.numpy.float32 to float32 for docstring check
    
* skip cache test for FlaxMistralForSequenceClassification since it doesn't score previous input_ids when pad_token_id is in input_ids
    
    * updated checkpoint since from_pt not included
    
    * applied black style
    
    * removed unused args
    
    * Applied styling and fixup
    
    * changed checkpoint for doc back
    
* fixed ref after adding it to hf hub
    
    * Add dummy ckpt
    
    * applied styling
    
    * added tokenizer to new ckpt
    
    * fixed slice format
    
    * fix init and slice
    
    * changed ref for placeholder TODO
    
    * added copies from Llama
    
    * applied styling
    
    * applied fix-copies
    
    * fixed docs
    
    * update weight dtype reconversion for sharded weights
    
    * removed Nullable input ids
    
    * Removed unnecessary output attentions in Module
    
* added embedding weight initialization
    
    * removed unused past_key_values
    
    * fixed deterministic
    
    * Fixed RMS Norm and added copied from
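
    RMSNorm, used by Mistral (and Llama), normalizes by the root mean square of the activations without subtracting the mean, unlike LayerNorm. A minimal NumPy sketch of the operation, not the PR's Flax module:

    ```python
    import numpy as np

    def rms_norm(x, weight, eps=1e-6):
        # Scale by the reciprocal root mean square; no mean subtraction
        # and no bias, which distinguishes RMSNorm from LayerNorm.
        variance = np.mean(np.square(x), axis=-1, keepdims=True)
        return weight * x / np.sqrt(variance + eps)

    out = rms_norm(np.array([1.0, 1.0, 1.0, 1.0]), np.ones(4))
    ```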
    
    * removed input_embeds
    
    * applied make style
    
    * removed nullable input ids from sequence classification model
    
    * added copied from GPTJ
    
    * added copied from Llama on FlaxMistralDecoderLayer
    
    * added copied from to FlaxMistralPreTrainedModel methods
    
    * fix test deprecation warning
    
    * freeze gpt neox random_params and fix copies
    
    * applied make style
    
    * fixed doc issue
    
* skipped docstring test to align # copied from
    
    * applied make style
    
    * removed FlaxMistralForSequenceClassification
    
    * removed unused padding_idx
    
    * removed more sequence classification
    
    * removed sequence classification
    
    * applied styling and consistency
    
    * added copied from in tests
    
    * removed sequence classification test logic
    
    * applied styling
    
    * applied make style
    
    * removed freeze and fixed copies
    
    * undo test change
    
    * changed repeat_kv to tile
    
    * fixed to key value groups
    
    * updated copyright year
    
* split causal_mask
    
    * empty to rerun failed pt_flax_equivalence test FlaxWav2Vec2ModelTest
    
    * went back to 2023 for tests_pr_documentation_tests
    
    * went back to 2024
    
    * changed tile to repeat
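
    The tile-versus-repeat fix matters for grouped-query attention: each key/value head must be duplicated consecutively so that each group of query heads shares one KV head. A minimal NumPy sketch (the Flax code would use `jnp.repeat` identically):

    ```python
    import numpy as np

    def repeat_kv(hidden, n_rep):
        # hidden: (batch, num_kv_heads, seq_len, head_dim).
        # np.repeat duplicates each KV head n_rep consecutive times,
        # giving head order [h0, h0, h1, h1] so query-head groups line up
        # with their KV head. np.tile would instead cycle [h0, h1, h0, h1],
        # pairing query heads with the wrong KV heads.
        return np.repeat(hidden, n_rep, axis=1)

    kv = np.arange(6).reshape(1, 2, 3, 1)  # 2 KV heads, 3 positions
    out = repeat_kv(kv, 2)                 # 4 heads after repetition
    ```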
    
    * applied make style
    
    * empty for retry on Wav2Vec2