    [feat] allow SDXL pipeline to run with fused QKV projections (#6030) · a2bc2e14
    Sayak Paul authored
    
    
    * debug
    
    * from step
    
    * print
    
    * turn sigma into a list
    
    * make str
    
    * init_noise_sigma
    
    * comment
    
    * remove prints
    
    * feat: introduce fused projections
    
    * change to a better name
    
    * no grad
    
    * device.
    
    * device
    
    * dtype
    
    * okay
    
    * print
    
    * more print
    
    * fix: unbind -> split
    
    * fix: qkv -> k
    
    * enable disable
    
    * apply attention processor within the method
    
    * attn processors
    
    * _enable_fused_qkv_projections
    
    * remove print
    
    * add fused projection to vae
    
    * add todos.
    
    * add: documentation and cleanups.
    
    * add: test for qkv projection fusion.
    
    * relax assertions.
    
    * relax further
    
    * fix: docs
    
    * fix-copies
    
    * correct error message.
    
    * Empty-Commit
    
    * better conditioning on disable_fused_qkv_projections
    
    * check
    
    * check processor
    
    * bfloat16 computation.
    
    * check latent dtype
    
    * style
    
    * remove copy temporarily
    
    * cast latent to bfloat16
    
    * fix: vae -> self.vae
    
    * remove print.
    
    * add _change_to_group_norm_32
    
    * comment out stuff that didn't work
    
    * Apply suggestions from code review
    Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
    
    * reflect patrick's suggestions.
    
    * fix imports
    
    * fix: disable call.
    
    * fix more
    
    * fix device and dtype
    
    * fix conditions.
    
    * fix more
    
    * Apply suggestions from code review
    Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
    
    ---------
    Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>