    Flax memory efficient attention (#2889) · dc277501
    Pedro Cuenca authored
    
    
    * add use_memory_efficient params placeholder
    
    * test
    
    * add memory efficient attention jax
    
    * add memory efficient attention jax
    
    * newline
    
    * forgot dot
    
    * Rename use_memory_efficient
    
    * Keep dtype last.
    
    * Actually use key_chunk_size
    
    * Rename symbol
    
    * Apply style
    
    * Rename use_memory_efficient
    
    * Keep dtype last
    
    * Pass `use_memory_efficient_attention` in `from_pretrained`
    
    * Move JAX memory efficient attention to attention_flax.
    
    * Simple test.
    
    * style
    
    ---------
    Co-authored-by: muhammad_hanif <muhammad_hanif@sofcograha.co.id>
    Co-authored-by: MuhHanif <48muhhanif@gmail.com>
attention_flax.py