"tests/models/auto/test_modeling_auto.py" did not exist on "c824d15aa1590ddb5d2fc977a8a1009a4b1d7262"
  • Add Llama Flax Implementation (#24587) · 75336c17
    Alex McKinney authored
    * Copies `modeling_flax_gpt_neo.py` to start
    
    * MLP Block. WIP Attention and Block
    
    * Adds Flax implementation of `LlamaMLP`
Validated with an in-file test.
Some slight numerical differences, but assuming they aren't an issue (sketch below).
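A minimal sketch of the block, assuming sizes are passed in directly (the real module reads them from `LlamaConfig`):

```python
import flax.linen as nn
import jax.numpy as jnp

class FlaxLlamaMLP(nn.Module):
    # illustrative attributes; the actual module takes a LlamaConfig
    hidden_size: int
    intermediate_size: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.gate_proj = nn.Dense(self.intermediate_size, use_bias=False, dtype=self.dtype)
        self.up_proj = nn.Dense(self.intermediate_size, use_bias=False, dtype=self.dtype)
        self.down_proj = nn.Dense(self.hidden_size, use_bias=False, dtype=self.dtype)

    def __call__(self, hidden_states):
        # SwiGLU: down(silu(gate(x)) * up(x)), matching the PyTorch LlamaMLP
        return self.down_proj(nn.silu(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
```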
    
    * Adds `FlaxLlamaRMSNorm` layer
`flax.linen` ships an `RMSNorm` layer, but not in all supported versions, so we define one in-file (sketch below).
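A sketch of the layer, assuming `hidden_size` and `eps` are passed in directly (the real module reads `hidden_size` and `rms_norm_eps` from `LlamaConfig`):

```python
import flax.linen as nn
import jax.numpy as jnp

class FlaxLlamaRMSNorm(nn.Module):
    hidden_size: int  # illustrative; the actual module takes a LlamaConfig
    eps: float = 1e-6
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.weight = self.param("weight", lambda _, shape: jnp.ones(shape), (self.hidden_size,))

    def __call__(self, hidden_states):
        # accumulate the variance in float32 for stability
        variance = jnp.power(hidden_states.astype(jnp.float32), 2).mean(-1, keepdims=True)
        hidden_states = hidden_states / jnp.sqrt(variance + self.eps)
        return self.weight * hidden_states.astype(self.dtype)
```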
    
    * Adds FlaxLlamaAttention
Copied from GPT-J, as it has an efficient caching implementation as well as rotary embeddings.
Outputs are numerically different, but not by a huge amount; needs investigating.
    
    * Adds `FlaxLlamaDecoderLayer`
numerically inaccurate; debugging...
    
    * debugging rotary mismatch
GPT-J uses interleaved rotation whilst Llama uses contiguous rotation (sketch below).
I think they match now, but the final result is still wrong; maybe drop back to debugging just the attention layer?
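For reference, a sketch of the two rotation conventions (function names follow the upstream implementations; treat the exact shapes as assumptions):

```python
import jax.numpy as jnp

def rotate_every_two(x):
    # GPT-J style (interleaved): pairs (x0, x1), (x2, x3), ... -> (-x1, x0, -x3, x2, ...)
    rotated = jnp.stack((-x[..., 1::2], x[..., ::2]), axis=-1)
    return rotated.reshape(x.shape)

def rotate_half(x):
    # Llama style (contiguous): second half negated and swapped with the first half
    return jnp.concatenate((-x[..., x.shape[-1] // 2:], x[..., : x.shape[-1] // 2]), axis=-1)
```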
    
    * fixes bug with decoder layer
    still somewhat numerically inaccurate, but close enough for now
    
    * adds markers for what to implement next
    the structure here diverges a lot from the PT version.
not a big fan of it, but just getting something working for now
    
* implements `FlaxLlamaBlockCollection`
tolerance must be set higher than expected, which is kinda disconcerting
    
    * Adds `FlaxLlamaModule`
    equivalent PyTorch model is `LlamaModel`
yay! a language model
    
    * adds `FlaxLlamaForCausalLMModule`
    equivalent to `LlamaForCausalLM`
still missing returning a dict or tuple; will add later
    
    * start porting pretrained wrappers
realised it probably needs `return_dict` support as a prerequisite
    
    * cleanup, quality, style
    
    * readds `return_dict` and model output named tuples
    
* (tentatively) pretrained wrappers work 🔥
    
    * fixes numerical mismatch in `FlaxLlamaRMSNorm`
seems `jax.lax.rsqrt` does not match `torch.rsqrt`.
manually computing `1 / jax.numpy.sqrt` results in matching values (repro sketch below).
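A minimal repro of the kind of check used here; the exact low-order behaviour is backend-dependent, so treat the printed delta as illustrative:

```python
import jax
import jax.numpy as jnp

variance = jnp.float32(0.1234567)
eps = 1e-6

fast = jax.lax.rsqrt(variance + eps)    # hardware reciprocal-sqrt approximation
exact = 1.0 / jnp.sqrt(variance + eps)  # what the fix computes instead
print(fast - exact)                     # small but potentially nonzero difference
```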
    
    * [WIP] debugging numerics
    
    * numerical match
I think the issue was an accidental change of backend; forcing the CPU backend fixes the test (see below).
We expect some mismatch on GPU.
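Two ways to pin the numerics test to CPU; both are assumptions about how the backend was forced, not quotes from the test:

```python
import jax

# option 1: set the default platform before any computation runs
jax.config.update("jax_platform_name", "cpu")

# option 2: scope a single block of work to the CPU device
with jax.default_device(jax.devices("cpu")[0]):
    pass  # run the equivalence check here
```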
    
    * adds in model and integration tests for Flax Llama
summary of failing tests:
- `mul`: invalid combination of dimensions
- one numerical mismatch
- bf16 conversion (maybe an issue with my local backend)
- params are not a FrozenDict
    
    * adds missing TYPE_CHECKING import and `make fixup`
    
    * adds back missing docstrings
needs review on docstring quality; not sure what is required.
Furthermore, we need to check whether `CHECKPOINT_FOR_DOC` is valid. See TODO
    
* commenting out equivalence test, as we can just use the common one
    
    * debugging
    
    * Fixes bug where mask and pos_ids were swapped in pretrained models
This results in all tests passing now 🔥

    * cleanup of modeling file
    
    * cleanup of test file
    
    * Resolving simpler review comments
    
    * addresses more minor review comments
    
* fixing pytest errors introduced by review changes
    
    * wip additional slow tests
    
    * wip tests
    need to grab a GPU machine to get real logits for comparison
    otherwise, slow tests should be okay
    
    * `make quality`, `make style`
    
    * adds slow integration tests
    - checking logits
    - checking hidden states
    - checking generation outputs
    
    * `make fix-copies`
    
    * fix mangled function following `make fix-copies`
    
    * adds missing type checking imports
    
    * fixes missing parameter checkpoint warning
    
* more fine-grained 'Copied from' tags
avoids the issue of overwriting `LLAMA_INPUTS_DOCSTRING` (illustrated below)
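The tag can sit on a single method instead of the whole class, so `make fix-copies` only rewrites that method; the paths below are illustrative:

```python
class FlaxLlamaPreTrainedModel:
    # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel.init_weights with GPTNeo->Llama
    def init_weights(self, rng, input_shape, params=None):
        ...
```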
    
    * swaps import guards
    ??? how did these get swapped initially?
    
* removing `inv_freq` again, as the PyTorch version has now removed it
    
    * attempting to get CI to pass
    
    * adds doc entries for llama flax models
    
    * fixes typo in __init__.py imports
    
    * adds back special equivalence tests
these come from the GPT-Neo Flax tests; there is special behaviour for these models that needs to override the common version
    
    * overrides tests with dummy to see if CI passes
    need to fill in these tests later
    
    * adds my contribution to docs
    
    * `make style; make quality`
    
* replaces random masking with a fixed mask to work with the Flax version
    
    * `make quality; make style`
    
    * Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
    
    * Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
    
    * Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
    
    * Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
    
    * Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
    
    * Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
    
    * updates `x`->`tensor` in `rotate_half`
    
    * addresses smaller review comments
    
    * Update docs/source/en/model_doc/llama.md
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
    
    * adds integration test class
    
* adds `dtype` to rotary embedding to cast outputs (sketch below)
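A sketch of the idea, with an illustrative helper name: the sin/cos table is built in float32 and only the outputs are cast.

```python
import jax.numpy as jnp

def create_sinusoidal_positions(num_pos, dim, dtype=jnp.float32):
    # build the rotary angle table in float32, cast only at the end
    inv_freq = 1.0 / (10000 ** (jnp.arange(0, dim, 2, dtype=jnp.float32) / dim))
    freqs = jnp.outer(jnp.arange(num_pos, dtype=jnp.float32), inv_freq)
    emb = jnp.concatenate((freqs, freqs), axis=-1)
    return jnp.sin(emb).astype(dtype), jnp.cos(emb).astype(dtype)
```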
    
    * adds type to flax llama rotary layer
    
    * `make style`
    
    * `make fix-copies`
    
    * Apply suggestions from code review
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
    
    * applies suggestions from review
    
    * Update modeling_flax_llama.py
    
    * `make fix-copies`
    
    * Update tests/models/llama/test_modeling_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
    
    * Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
    
    * fixes shape mismatch in FlaxLlamaMLP
    
    * applies some suggestions from reviews
    
* casts attn output logits to f32 regardless of dtype (sketch below)
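A sketch of the pattern, not the exact upstream helper: compute the softmax over the attention logits in float32, then cast back to the module dtype.

```python
import jax
import jax.numpy as jnp

def attn_weights_in_f32(attn_logits, dtype):
    # softmax in float32 regardless of the module dtype, then cast back
    weights = jax.nn.softmax(attn_logits.astype(jnp.float32), axis=-1)
    return weights.astype(dtype)
```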
    
* adds attn bias using `LlamaConfig.attention_bias` (sketch below)
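A sketch of how the flag plugs in; `config` stands for a `LlamaConfig` and the helper name is illustrative:

```python
import flax.linen as nn
import jax.numpy as jnp

def make_attn_proj(features, config, dtype=jnp.float32):
    # q/k/v/o projections toggle their bias term from LlamaConfig.attention_bias
    return nn.Dense(features, use_bias=config.attention_bias, dtype=dtype)
```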
    
    * adds Copied From comments to Flax Llama test
    
* Mistral and Persimmon test changes: copied from Llama
    
    * updates docs index
    
* removes Copied from in tests
it was preventing `make fix-copies` from succeeding
    
    * quality and style
    
    * ignores FlaxLlama input docstring
    
    * adds revision to `_CHECKPOINT_FOR_DOC`
    
    * repo consistency and quality
    
    * removes unused import
    
* removes Copied from from Phi test
now diverges from Llama tests following FlaxLlama changes
    
    * adds `_REAL_CHECKPOINT_FOR_DOC`
    
* removes refs from PR tests
    
    * reformat to make ruff happy
    
    ---------
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>