    SDXL flax (#4254) · 3651b14c
    Pedro Cuenca authored
    
    
    * support transformer_layers_per_block in Flax UNet
    
    * add support for text_time additional embeddings to Flax UNet
    
    * rename attention layers for VAE
    
    * add shape asserts when renaming attention layers
    
    * transpose VAE attention layers
    
    * add pipeline flax SDXL code [WIP]
    
    * continue adding pipeline flax SDXL code [WIP]
    
    * cleanup
    
    * Working on JIT support
    
    Fixed prompt embedding shapes so they work in parallel mode. Assuming we
    always have both text encoders for now, for simplicity.
    
    * Fixing embeddings (untested)
    
    * Remove spurious line
    
    * Shard guidance_scale when jitting.
    
    * Decode images
    
    * Fix sharding
    
    * style
    
    * Refiner UNet can be loaded.
    
    * Refiner / img2img pipeline
    
    * Allow latent outputs from base and latent inputs in refiner
    
    This makes it possible to chain base + refiner without having to use the
    vae decoder in the base model, the vae encoder in the refiner, skipping
    conversions to/from PIL, and avoiding TPU <-> CPU memory copies.
    
    * Adapt to FlaxCLIPTextModelOutput
    
    * Update Flax XL pipeline to FlaxCLIPTextModelOutput
    
    * make fix-copies
    
    * make style
    
    * add euler scheduler
    
    * Fix import
    
    * Fix copies, comment unused code.
    
    * Fix SDXL Flax imports
    
    * Fix euler discrete begin
    
    * improve init import
    
    * finish
    
    * put discrete euler in init
    
    * fix flax euler
    
    * Fix more
    
    * make style
    
    * correct init
    
    * correct init
    
    * Temporarily remove FlaxStableDiffusionXLImg2ImgPipeline
    
    * correct pipelines
    
    * finish
    
    ---------
    Co-authored-by: Martin Müller <martin.muller.me@gmail.com>
    Co-authored-by: patil-suraj <surajp815@gmail.com>
    Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
__init__.py 25.4 KB