    Flax support for Stable Diffusion 2 (#1423) · 4d1e4e24
    Pedro Cuenca authored
    
    
    * Flax: start adapting to Stable Diffusion 2
    
    * More changes.
    
    * attention_head_dim can be a tuple.
    
    * Fix typos
    
    * Add simple SD 2 integration test.
    
    The expected slice values were taken from my Ampere GPU.
    
    * Add simple UNet integration tests for Flax.
    
    Note that the expected values are taken from the PyTorch results. This
    ensures the Flax and PyTorch versions are not too far off.
    
    * Apply suggestions from code review
    Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
    
    * Typos and style
    
    * Tests: verify jax is available.
    
    * Style
    
    * Make flake8 happy
    
    * Remove typo.
    
    * Simple Flax SD 2 pipeline tests.
    
    * Import order
    
    * Remove unused import.
    Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
    Co-authored-by: @camenduru
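The "attention_head_dim can be a tuple" change is needed because Stable Diffusion 2 configures its UNet with a different attention value per block (its published config uses `[5, 10, 20, 20]`), whereas Stable Diffusion 1.x uses a single integer (`8`). The snippet below is a minimal sketch of how such a setting can be normalized into one value per down block. The helper name `normalize_attention_head_dim` is hypothetical and is not a diffusers API; the Flax UNet in this commit may perform the equivalent step inline.

```python
from typing import Sequence, Tuple, Union


def normalize_attention_head_dim(
    attention_head_dim: Union[int, Sequence[int]],
    down_block_types: Sequence[str],
) -> Tuple[int, ...]:
    """Hypothetical helper: return one attention_head_dim value per down block.

    An int is broadcast to every block (SD 1.x style); a sequence must already
    contain exactly one entry per block (SD 2 style).
    """
    if isinstance(attention_head_dim, int):
        # Single value: reuse it for every down block.
        return (attention_head_dim,) * len(down_block_types)
    dims = tuple(attention_head_dim)
    if len(dims) != len(down_block_types):
        raise ValueError(
            f"attention_head_dim has {len(dims)} entries but there are "
            f"{len(down_block_types)} down blocks"
        )
    return dims


# Typical SD UNet layout: three cross-attention down blocks plus one plain block.
blocks = ("CrossAttnDownBlock2D",) * 3 + ("DownBlock2D",)
print(normalize_attention_head_dim(8, blocks))                # SD 1.x style
print(normalize_attention_head_dim([5, 10, 20, 20], blocks))  # SD 2 style
```

The up path of a UNet visits blocks in the opposite order, so an implementation would typically read the same tuple reversed when building the up blocks (again, an assumption about structure rather than a quote of this commit).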