    Generate: New `Cache` abstraction and Attention Sinks support (#26681) · 633215ba
    Tom Aarsen authored
    * Draft version of new KV Caching
    
    This should allow Attention Sinks (https://github.com/tomaarsen/attention_sinks)
    / StreamingLLM (https://arxiv.org/abs/2309.17453) to be easily implemented
    in a third-party library or directly in transformers.
    
    * Address numerous PR suggestions
    
    1. Move layer_idx from the cache to ...Attention. Removes the confusing set_layer_idx magic.
    2. Always convert past_key_values to a Cache instance at the start of ...Attention; removes all other isinstance calls.
    3. Remove the __bool__ and __getitem__ magic, as they're confusing.
    4. past_key_values.update(key, value, idx) now returns key, value (see the sketch below).
    5. Add a use_legacy_cache flag, defaulting to None, i.e. falsy. This breaks generate for now, until either 1) the cache is used in generate(), or 2) use_legacy_cache is defaulted to True in generate() until we change it in another PR.
    6. Separate key_cache and value_cache.
    
    Some work is still needed to see if the SinkCache can conveniently be implemented with just one update method.
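
    For reference, a minimal sketch of the update() contract from points 4 and 6 above. This is a simplified illustration of the idea, not the exact class added here:

    ```python
    import torch

    class Cache:
        """Grows the per-layer key/value tensors as generation advances."""

        def __init__(self):
            self.key_cache = []    # one tensor per layer: [batch, num_heads, seq_len, head_dim]
            self.value_cache = []

        def update(self, key_states, value_states, layer_idx):
            if len(self.key_cache) <= layer_idx:
                # Prefill: this layer has no cached states yet
                self.key_cache.append(key_states)
                self.value_cache.append(value_states)
            else:
                # Decoding: append the new states along the sequence axis
                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
            # Point 4: the attention layer attends over the returned full-length tensors
            return self.key_cache[layer_idx], self.value_cache[layer_idx]
    ```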
    
    * Implement the SinkCache through backward+forward rotations
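
    The rotation trick, for context: with rotary embeddings the keys are cached after the position-dependent rotation, so evicting tokens from the window shifts every kept key to a new position and its rotation must be corrected. Rotating backward by the old angle and forward by the new one collapses into a single rotation by the position delta. A hypothetical helper, assuming the rotate_half RoPE convention used in Llama (names are illustrative):

    ```python
    import torch

    def rotate_half(x):
        # Standard Llama-style RoPE helper: (x1, x2) -> (-x2, x1)
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def rerotate_keys(keys, cos, sin, shift):
        """Shift already-rotated keys back by `shift` positions, i.e. rotate by -shift * theta.

        keys: [batch, num_heads, seq_len, head_dim], rotated to their old positions
        cos/sin: RoPE tables indexed by position, broadcastable over keys
        shift: number of evicted tokens, identical for every kept key
        """
        # cos is even and sin is odd, so a negative angle only flips the sign of sin
        return keys * cos[shift] + rotate_half(keys) * (-sin[shift])
    ```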
    
    * Integrate (Sink)Cache with Llama FA2
    
    * Set use_legacy_cache=True as the default; allows tests to pass
    
    * Move from/to_legacy_cache to ...Model class
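
    For context, the legacy format is a tuple with one (key, value) pair per layer. A rough sketch of the two converters; where they finally live (...Model here, possibly the cache class later) and their exact signatures are not fixed by this commit:

    ```python
    class Cache:
        """Sketch of the conversion between the new cache and the legacy tuple format."""

        def __init__(self):
            self.key_cache = []
            self.value_cache = []

        @classmethod
        def from_legacy_cache(cls, past_key_values):
            # Legacy format: ((key_0, value_0), (key_1, value_1), ...)
            cache = cls()
            if past_key_values is not None:
                for key_states, value_states in past_key_values:
                    cache.key_cache.append(key_states)
                    cache.value_cache.append(value_states)
            return cache

        def to_legacy_cache(self):
            # Back to the tuple-of-tuples format legacy callers expect
            return tuple(zip(self.key_cache, self.value_cache))
    ```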
    
    * Undo unnecessary newline change
    
    * Remove copy utility from deprecated OpenLlama
    
    * Match import style
    
    * manual rebase with main
    
    * Cache class working with generate (#1)
    
    * Draft version of new KV Caching
    
    This should allow Attention Sinks (https://github.com/tomaarsen/attention_sinks)
    / StreamingLLM (https://arxiv.org/abs/2309.17453) to be easily implemented
    in a third-party library or directly in transformers.
    
    * Address numerous PR suggestions
    
    1. Move layer_idx from the cache to ...Attention. Removes the confusing set_layer_idx magic.
    2. Always convert past_key_values to a Cache instance at the start of ...Attention; removes all other isinstance calls.
    3. Remove the __bool__ and __getitem__ magic, as they're confusing.
    4. past_key_values.update(key, value, idx) now returns key, value.
    5. Add a use_legacy_cache flag, defaulting to None, i.e. falsy. This breaks generate for now, until either 1) the cache is used in generate(), or 2) use_legacy_cache is defaulted to True in generate() until we change it in another PR.
    6. Separate key_cache and value_cache.
    
    Some work is still needed to see if the SinkCache can conveniently be implemented with just one update method.
    
    * Integrate (Sink)Cache with Llama FA2
    
    * Move from/to_legacy_cache to ...Model class
    
    * Undo unnecessary newline change
    
    * Match import style
    
    * working generate
    
    * Add tests; Simplify code; Apply changes to Mistral and Persimmon
    
    * fix rebase mess
    
    * a few more manual fixes
    
    * last manual fix
    
    * propagate changes to phi
    
    * upgrade test
    
    * add use_legacy_cache docstring; beef up tests
    
    * reintroduce unintentionally deleted code
    
    ---------
    Co-authored-by: Tom Aarsen <Cubiegamedev@gmail.com>
    
    * move import
    
    * add default to model_kwargs.get('use_legacy_cache')
    
    * correct failing test
    
    * Apply suggestions from code review
    Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
    
    * apply PR suggestions
    
    * fix failing test
    
    * Apply suggestions from code review
    Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
    Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
    
    * PR comments
    
    * tmp commit
    
    * add docstrings
    
    * more tests, more docstrings, add to docs
    
    * derp
    
    * tmp commit
    
    * tmp dbg
    
    * more dbg
    
    * fix beam search bug
    
    * cache can be a list of tuples in some models
    
    * fix group beam search
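
    Background on the beam search fixes: beam search reorders hypotheses at every step, so the per-layer cache tensors have to follow the selected beams along the batch dimension. A hypothetical sketch of such a reordering (the hook name and signature are assumptions):

    ```python
    import torch

    def reorder_cache(cache, beam_idx):
        # beam_idx: LongTensor of shape [batch_size * num_beams] selecting the surviving beams
        for layer_idx in range(len(cache.key_cache)):
            device = cache.key_cache[layer_idx].device
            cache.key_cache[layer_idx] = cache.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            cache.value_cache[layer_idx] = cache.value_cache[layer_idx].index_select(0, beam_idx.to(device))
    ```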
    
    * all tests passing except the SinkCache integration tests
    
    * fix sink cache and add hard integration test
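
    The integration test presumably exercises end-to-end generation along these lines (the model name and exact constructor arguments are illustrative, not taken from this commit):

    ```python
    from transformers import AutoModelForCausalLM, AutoTokenizer, SinkCache

    model_id = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    inputs = tokenizer("A very long prompt ...", return_tensors="pt")
    # Keep a few attention-sink tokens plus a sliding window over recent tokens
    past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
    outputs = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=64)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
    ```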
    
    * now also compatible with inputs_embeds input
    
    * PR comments
    
    * add Cache support to Phi+FA2
    
    * make fixup
    
    ---------
    Co-authored-by: Joao Gante <joao@huggingface.co>
    Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
    Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>