1. 06 Feb, 2024 1 commit
  2. 07 Dec, 2023 1 commit
• Add Llama Flax Implementation (#24587) · 75336c17
      Alex McKinney authored
      * Copies `modeling_flax_gpt_neo.py` to start
      
      * MLP Block. WIP Attention and Block
      
* Adds Flax implementation of `LlamaMLP`
Validated with an in-file test.
Some slight numeric differences, but assuming they aren't an issue
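For reference, a minimal sketch of the gated MLP in Flax (module layout and names are assumptions here, not necessarily the merged code):

```python
import flax.linen as nn
import jax.numpy as jnp

class FlaxLlamaMLP(nn.Module):
    """Sketch of Llama's gated MLP: down_proj(silu(gate_proj(x)) * up_proj(x))."""
    hidden_size: int
    intermediate_size: int
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, hidden_states):
        gate = nn.Dense(self.intermediate_size, use_bias=False, dtype=self.dtype, name="gate_proj")(hidden_states)
        up = nn.Dense(self.intermediate_size, use_bias=False, dtype=self.dtype, name="up_proj")(hidden_states)
        # SwiGLU-style gating, as in the PyTorch LlamaMLP
        return nn.Dense(self.hidden_size, use_bias=False, dtype=self.dtype, name="down_proj")(nn.silu(gate) * up)
```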
      
* Adds `FlaxLlamaRMSNorm` layer
`flax.linen` includes an `RMSNorm` layer, but not necessarily in all
versions. Hence, we add one in-file.
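A hedged sketch of what such an in-file RMSNorm looks like (the parameter name and init are assumptions):

```python
import flax.linen as nn
import jax.numpy as jnp

class FlaxLlamaRMSNorm(nn.Module):
    """Sketch of RMSNorm: scale by the reciprocal root-mean-square of the last axis."""
    hidden_size: int
    eps: float = 1e-6

    @nn.compact
    def __call__(self, hidden_states):
        weight = self.param("weight", lambda rng: jnp.ones(self.hidden_size))
        variance = jnp.mean(jnp.square(hidden_states.astype(jnp.float32)), axis=-1, keepdims=True)
        # 1 / sqrt rather than jax.lax.rsqrt -- see the numerical-mismatch fix later in this log
        return (weight * hidden_states * (1 / jnp.sqrt(variance + self.eps))).astype(hidden_states.dtype)
```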
      
* Adds `FlaxLlamaAttention`
Copied from GPT-J as it has an efficient caching implementation as well as
rotary embeddings.
Noticeably numerically different, but not by a huge amount. Needs
investigating.
      
* Adds `FlaxLlamaDecoderLayer`
numerically inaccurate, debugging...
      
* debugging rotary mismatch
GPT-J uses interleaved rotary pairs whilst Llama uses contiguous halves.
I think they match now, but the final result is still wrong.
Maybe drop back to just debugging the attention layer?
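The two rotary conventions side by side, following the PyTorch reference implementations (a sketch, not necessarily the merged Flax code):

```python
import jax.numpy as jnp

def rotate_half_gptj(x):
    # Interleaved pairs: (x0, x1), (x2, x3), ... -> (-x1, x0), (-x3, x2), ...
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return jnp.stack((-x2, x1), axis=-1).reshape(x.shape)

def rotate_half_llama(x):
    # Contiguous halves: (x[:d/2], x[d/2:]) -> (-x[d/2:], x[:d/2])
    half = x.shape[-1] // 2
    return jnp.concatenate((-x[..., half:], x[..., :half]), axis=-1)
```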
      
      * fixes bug with decoder layer
      still somewhat numerically inaccurate, but close enough for now
      
      * adds markers for what to implement next
      the structure here diverges a lot from the PT version.
      not a big fan of it, but just get something working for now
      
* implements `FlaxLlamaBlockCollection`
      tolerance must be higher than expected, kinda disconcerting
      
      * Adds `FlaxLlamaModule`
      equivalent PyTorch model is `LlamaModel`
      yay! a language model🤗
      
      * adds `FlaxLlamaForCausalLMModule`
      equivalent to `LlamaForCausalLM`
      still missing returning dict or tuple, will add later
      
      * start porting pretrained wrappers
realised it probably needs `return_dict` as a prereq
      
      * cleanup, quality, style
      
      * readds `return_dict` and model output named tuples
      
      * (tentatively) pretrained wrappers work 🔥
      
      * fixes numerical mismatch in `FlaxLlamaRMSNorm`
seems `jax.lax.rsqrt` does not match `torch.rsqrt`.
      manually computing `1 / jax.numpy.sqrt` results in matching values.
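A minimal repro of the kind of mismatch described (deltas depend on backend; illustrative only):

```python
import jax
import jax.numpy as jnp

x = jnp.linspace(0.1, 10.0, 5)
a = jax.lax.rsqrt(x)   # hardware reciprocal-sqrt; can differ in the last bits
b = 1 / jnp.sqrt(x)    # the formulation that matched the PyTorch values
print(jnp.max(jnp.abs(a - b)))
```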
      
      * [WIP] debugging numerics
      
      * numerical match
I think the issue was an accidental change of backend; forcing CPU fixes the test.
We expect some mismatch on GPU.
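One way to force that when re-running locally (must run before any JAX computation):

```python
import jax

# Pin JAX to the CPU backend; GPU kernels may legitimately differ from
# the PyTorch reference in low-order bits.
jax.config.update("jax_platform_name", "cpu")
print(jax.devices())  # [CpuDevice(id=0)]
```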
      
      * adds in model and integration tests for Flax Llama
summary of failing tests:
      - mul invalid combination of dimensions
      - one numerical mismatch
      - bf16 conversion (maybe my local backend issue)
      - params are not FrozenDict
      
      * adds missing TYPE_CHECKING import and `make fixup`
      
      * adds back missing docstrings
      needs review on quality of docstrings, not sure what is required.
      Furthermore, need to check if `CHECKPOINT_FOR_DOC` is valid. See TODO
      
* commenting out equivalence test as we can just use the common one
      
      * debugging
      
      * Fixes bug where mask and pos_ids were swapped in pretrained models
This results in all tests passing now 🔥

* cleanup of modeling file
      
      * cleanup of test file
      
      * Resolving simpler review comments
      
      * addresses more minor review comments
      
      * fixing introduced pytest errors from review
      
      * wip additional slow tests
      
      * wip tests
      need to grab a GPU machine to get real logits for comparison
      otherwise, slow tests should be okay
      
      * `make quality`, `make style`
      
      * adds slow integration tests
      - checking logits
      - checking hidden states
      - checking generation outputs
      
      * `make fix-copies`
      
      * fix mangled function following `make fix-copies`
      
      * adds missing type checking imports
      
      * fixes missing parameter checkpoint warning
      
* more fine-grained 'Copied from' tags
avoids the issue of overwriting `LLAMA_INPUTS_DOCSTRING`
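For context, the `# Copied from` convention at class vs. method granularity (the exact paths and method here are illustrative):

```python
# Class-level tag (broad: `make fix-copies` syncs the whole class body):
# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel with GPTNeo->Llama

# Method-level tag (only this function is kept in sync, so the Llama
# docstring constants nearby are left alone):
# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel.init_weights
def init_weights(self, rng, input_shape, params=None):
    ...
```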
      
      * swaps import guards
      ??? how did these get swapped initially?
      
* removing `inv_freq` again as the PyTorch version has now removed it
      
      * attempting to get CI to pass
      
      * adds doc entries for llama flax models
      
      * fixes typo in __init__.py imports
      
      * adds back special equivalence tests
these come from the GPT-Neo Flax tests. there is special behaviour for these models that requires overriding the common version
      
      * overrides tests with dummy to see if CI passes
      need to fill in these tests later
      
      * adds my contribution to docs
      
      * `make style; make quality`
      
* replaces random masking with fixed masking to work with the flax version
      
      * `make quality; make style`
      
* Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
      
      * updates `x`->`tensor` in `rotate_half`
      
      * addresses smaller review comments
      
      * Update docs/source/en/model_doc/llama.md
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
      
      * adds integration test class
      
      * adds `dtype` to rotary embedding to cast outputs
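Sketch of what casting the rotary tables to a configurable `dtype` looks like (the function name and signature are assumptions):

```python
import jax.numpy as jnp

def rotary_sin_cos(positions, dim, dtype=jnp.float32, base=10000.0):
    """Sketch: standard RoPE tables, cast at the end so bf16 models get bf16 outputs."""
    inv_freq = 1.0 / (base ** (jnp.arange(0, dim, 2) / dim))
    freqs = jnp.einsum("i,j->ij", positions.astype(jnp.float32), inv_freq)
    emb = jnp.concatenate((freqs, freqs), axis=-1)  # Llama-style contiguous layout
    return jnp.sin(emb).astype(dtype), jnp.cos(emb).astype(dtype)
```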
      
      * adds type to flax llama rotary layer
      
      * `make style`
      
      * `make fix-copies`
      
      * Apply suggestions from code review
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
      
      * applies suggestions from review
      
      * Update modeling_flax_llama.py
      
      * `make fix-copies`
      
      * Update tests/models/llama/test_modeling_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
      
      * Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
      
      * fixes shape mismatch in FlaxLlamaMLP
      
      * applies some suggestions from reviews
      
      * casts attn output logits to f32 regardless of dtype
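One common way to do this, and roughly what's meant here (a sketch, assuming the cast applies at the softmax; `attn_scores` stands in for the pre-softmax attention logits):

```python
import jax
import jax.numpy as jnp

def softmax_in_f32(attn_scores, dtype):
    # The numerically sensitive step stays in float32 even for bf16/fp16
    # models, then the weights are cast back to the model dtype.
    return jax.nn.softmax(attn_scores.astype(jnp.float32), axis=-1).astype(dtype)
```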
      
      * adds attn bias using `LlamaConfig.attention_bias`
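Sketched wiring for that flag, mirroring `LlamaConfig.attention_bias` (a toy module, not the merged attention class):

```python
import flax.linen as nn

class QKVProjections(nn.Module):
    """Toy sketch: projection bias follows the config instead of being hard-coded off."""
    embed_dim: int
    attention_bias: bool = False  # mirrors LlamaConfig.attention_bias

    @nn.compact
    def __call__(self, hidden_states):
        q = nn.Dense(self.embed_dim, use_bias=self.attention_bias, name="q_proj")(hidden_states)
        k = nn.Dense(self.embed_dim, use_bias=self.attention_bias, name="k_proj")(hidden_states)
        v = nn.Dense(self.embed_dim, use_bias=self.attention_bias, name="v_proj")(hidden_states)
        return q, k, v
```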
      
      * adds Copied From comments to Flax Llama test
      
* mistral and persimmon test change - copy from llama
      
      * updates docs index
      
* removes Copied from in tests
it was preventing `make fix-copies` from succeeding
      
      * quality and style
      
      * ignores FlaxLlama input docstring
      
      * adds revision to `_CHECKPOINT_FOR_DOC`
      
      * repo consistency and quality
      
      * removes unused import
      
* removes `Copied from` from Phi test
now diverges from llama tests following FlaxLlama changes
      
      * adds `_REAL_CHECKPOINT_FOR_DOC`
      
      * removes refs from pr tests
      
      * reformat to make ruff happy
      
      ---------
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
  3. 03 Nov, 2023 1 commit
• [Docs] Model_doc structure/clarity improvements (#26876) · 5964f820
      Maria Khalusova authored
      * first batch of structure improvements for model_docs
      
      * second batch of structure improvements for model_docs
      
      * more structure improvements for model_docs
      
      * more structure improvements for model_docs
      
      * structure improvements for cv model_docs
      
      * more structural refactoring
      
      * addressed feedback about image processors
  4. 05 Sep, 2023 1 commit
  5. 18 Jul, 2023 1 commit
  6. 20 Jun, 2023 1 commit
  7. 09 Jun, 2023 1 commit
  8. 07 May, 2023 1 commit
  9. 12 Apr, 2023 1 commit
  10. 06 Apr, 2023 1 commit
• Adding Llama FastTokenizer support. (#22264) · 1670be4b
      Nicolas Patry authored
      * Adding Llama FastTokenizer support.
      
- Requires the `tokenizers` version from https://github.com/huggingface/tokenizers/pull/1183
- Only supports byte_fallback for llama, raises otherwise (safety net).
- Lots of open questions concern special tokens
      
      How to test:
      
```python
      from transformers.convert_slow_tokenizer import convert_slow_tokenizer
      from transformers import AutoTokenizer
      from tokenizers import Tokenizer
      
      tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")
      
# Flip to True to reload a previously converted fast tokenizer from disk
if False:
    new_tokenizer = Tokenizer.from_file("tok.json")
else:
    new_tokenizer = convert_slow_tokenizer(tokenizer)
    new_tokenizer.save("tok.json")
      
      strings = [
          "This is a test",
          "生活的真谛是",
          "生活的真谛是[MASK]。",
          # XXX: This one is problematic because of special tokens
          # "<s> Something something",
      ]
      
# The fast tokenizer must produce the same ids and the same decoded text
# as the original slow tokenizer.
for string in strings:
          encoded = tokenizer(string)["input_ids"]
          encoded2 = new_tokenizer.encode(string).ids
      
          assert encoded == encoded2, f"{encoded} != {encoded2}"
      
          decoded = tokenizer.decode(encoded)
          decoded2 = new_tokenizer.decode(encoded2)
      
          assert decoded.strip() == decoded2, f"{repr(decoded)} != {repr(decoded2)}"
      ```
      
      The converter + some test script.
      
      The test script.
      
      Tmp save.
      
      Adding Fast tokenizer + tests.
      
      Adding the tokenization tests.
      
      Correct combination.
      
      Small fix.
      
      Fixing tests.
      
      Fixing with latest update.
      
      Rebased.
      
fix copies + normalized added tokens + copies.
      
      Adding doc.
      
      TMP.
      
      Doc + split files.
      
      Doc.
      
      Versions + try import.
      
      Fix Camembert + warnings -> Error.
      
      Fix by ArthurZucker.
      
      Not a decorator.
      
      * Fixing comments.
      
      * Adding more to docstring.
      
      * Doc rewriting.
  11. 03 Apr, 2023 1 commit
  12. 20 Mar, 2023 1 commit
  13. 17 Mar, 2023 3 commits
  14. 16 Mar, 2023 1 commit
• LLaMA Implementation (#21955) · 0041be5b
Jason Phang authored
      * LLaMA
      
      * sharding and docs
      
      * tweak
      
      * black
      
      * inits
      
      * ruff
      
      * LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP
      
      * init
      
      * no checkpoint
      
      * docs
      
      * ruff
      
      * type_vocab_size
      
      * tokenizer fixes
      
      * tokenizer fixes
      
      * Update tokenization_llama.py
      
      * Update tokenization_llama.py
      
      * Update configuration_llama.py
      
      * Update modeling_llama.py
      
      * tokenizer add_bos by default
      
      * licenses
      
      * remove decoder
      
      * norms and mlp
      
      * rope overhaul
      
      * tweaks
      
      * black
      
      * mention OPT implementation
      
      * off-by-one naming
      
      * typo
      
      * fix
      
      * tokenization fix and slicing bug
      
      * padding config
      
      * cleanup
      
      * black
      
      * update tests
      
      * undo typo
      
      * fix vocab caching logic
      
      * ruff
      
      * docbuilder
      
      * attn fix from BlackSamorez
      
      * initial feedback
      
      * typo
      
      * docs
      
      * llama case
      
      * llama case
      
      * load checkpoint docs
      
      * comment about tokenizer
      
      * tokenizer defaults
      
      * clear past_key_values if use_cache=False
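In sketch form (not the exact diff):

```python
def finalize_cache(next_decoder_cache, use_cache: bool):
    # Only hand past_key_values back to the caller when caching was requested.
    return next_decoder_cache if use_cache else None
```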
      
      * last tweaks
      
      * last tweaks
      
      * last tweaks
      
      * last tweaks
      
      ---------
Co-authored-by: Stella Biderman <stellabiderman@gmail.com>