    Flax pipeline pndm (#583) · ab3fd671
    Pedro Cuenca authored
    
    
    * WIP: flax FlaxDiffusionPipeline & FlaxStableDiffusionPipeline
    
    * todo comment
    
    * Fix imports
    
    * Fix imports
    
    * add dummies
    
    * Fix empty init
    
    * make pipeline work
    
    * up
    
    * Allow dtype to be overridden on model load.
    
    This may be a temporary solution until #567 is addressed.
    
    * Convert params to bfloat16 or fp16 after loading.
    
    This deals with the weights, not the model.
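Converting the weights after loading amounts to mapping a cast over the parameter pytree, leaving the model's own computation dtype alone. Below is a minimal sketch in JAX. It assumes the parameters are a nested dict of arrays. `cast_floating_params` is a hypothetical helper for illustration, not the function used in this commit.

```python
import jax
import jax.numpy as jnp


def cast_floating_params(params, dtype=jnp.bfloat16):
    """Hypothetical helper: cast every floating-point leaf of a parameter
    pytree to `dtype`, leaving non-floating leaves (e.g. int counters) as-is."""

    def _cast(x):
        if jnp.issubdtype(x.dtype, jnp.floating):
            return x.astype(dtype)
        return x

    # tree_map rebuilds the same nested structure with each leaf transformed.
    return jax.tree_util.tree_map(_cast, params)


params = {
    "dense": {"kernel": jnp.ones((4, 4), jnp.float32), "bias": jnp.zeros((4,), jnp.float32)},
    "step": jnp.array(0, jnp.int32),
}
half = cast_floating_params(params)
```

The original `params` tree is not modified. JAX arrays are immutable, so the function returns a new tree.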
    
    * Use Flax schedulers (typing, docstring)
    
    * PNDM: replace Python control flow with JAX functions.
    
    Otherwise jitting and parallelization don't work properly, because they
    cannot evaluate Python branches that depend on traced values.
    
    I temporarily removed `step_prk`.
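Under `jax.jit`, function arguments are replaced by tracers whose concrete values are unknown at trace time. A Python `if` on such a value therefore raises an error. The usual replacement is `jax.lax.cond` (or `jnp.where` for elementwise selection). The sketch below is a minimal illustration of that pattern. The function name, the threshold of 3, and both branch bodies are hypothetical and are not taken from scheduling_pndm_flax.py.

```python
import jax
import jax.numpy as jnp


@jax.jit
def select_step(counter, sample, warmup_update, main_update):
    # A Python `if counter < 3:` would fail here: `counter` is a tracer.
    # lax.cond traces both branches and selects one at run time instead.
    return jax.lax.cond(
        counter < 3,
        lambda s: s + warmup_update,  # hypothetical warm-up branch
        lambda s: s + main_update,    # hypothetical main branch
        sample,
    )


early = select_step(0, 1.0, 10.0, 100.0)
late = select_step(5, 1.0, 10.0, 100.0)
```

Both branches must return arrays of the same shape and dtype. That is one reason functional rewrites of schedulers tend to preallocate their buffers.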
    
    * Pass latents shape to scheduler set_timesteps()
    
    PNDMScheduler uses it to reserve space; other schedulers ignore it.
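Reserving space from the latents shape lets a scheduler keep a history of past model outputs in one fixed-size array instead of a growing Python list. That matters because values carried through jitted steps or `lax` loops need static shapes. The sketch below assumes a rolling buffer with 4 slots (a hypothetical choice). `init_history` and `push` are illustrative names, not the scheduler's actual methods.

```python
import jax
import jax.numpy as jnp


def init_history(latents_shape, num_slots=4, dtype=jnp.float32):
    # Allocate the whole buffer once, zero-filled, from the latents shape.
    return jnp.zeros((num_slots,) + tuple(latents_shape), dtype=dtype)


@jax.jit
def push(history, model_output):
    # Drop the oldest entry and append the newest; the shape never changes.
    return jnp.concatenate([history[1:], model_output[None]], axis=0)


h = init_history((1, 2, 2))
h = push(h, jnp.ones((1, 2, 2)))
```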
    
    * Wrap model imports inside availability checks.
    
    * Optionally return state in from_config.
    
    Useful for Flax schedulers.
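In JAX code, mutable values are usually kept outside the object and passed in and out explicitly. A scheduler built from a config therefore also needs a way to return its initial state. The sketch below shows that pattern. `ToyScheduler`, `ToySchedulerState`, `create_state`, and the `return_state` flag are hypothetical names for illustration, not the diffusers API.

```python
from typing import NamedTuple, Tuple, Union

import jax
import jax.numpy as jnp


class ToySchedulerState(NamedTuple):
    # NamedTuples are JAX pytrees, so the state can cross a jit boundary.
    counter: jnp.ndarray


class ToyScheduler:
    """Hypothetical stateless scheduler: config on the object, state separate."""

    def __init__(self, num_train_timesteps: int = 1000):
        self.num_train_timesteps = num_train_timesteps

    def create_state(self) -> ToySchedulerState:
        return ToySchedulerState(counter=jnp.array(0, dtype=jnp.int32))

    @classmethod
    def from_config(
        cls, config: dict, return_state: bool = False
    ) -> Union["ToyScheduler", Tuple["ToyScheduler", ToySchedulerState]]:
        scheduler = cls(**config)
        if return_state:
            return scheduler, scheduler.create_state()
        return scheduler

    def step(self, state: ToySchedulerState) -> ToySchedulerState:
        # Functional update: return a new state instead of mutating self.
        return state._replace(counter=state.counter + 1)


scheduler, state = ToyScheduler.from_config({"num_train_timesteps": 50}, return_state=True)
state = jax.jit(scheduler.step)(state)
```

Keeping the default `return_state=False` leaves existing callers that expect a single object unaffected.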
    
    * Do not convert model weights to dtype.
    
    * Re-enable PRK steps with functional implementation.
    
    The returned values have not yet been verified for correctness.
    
    * Remove leftover has_state var.
    
    * make style
    
    * Apply suggestion list -> tuple
    Co-authored-by: Suraj Patil <surajp815@gmail.com>
    
    * Apply suggestion list -> tuple
    Co-authored-by: Suraj Patil <surajp815@gmail.com>
    
    * Remove unused comments.
    
    * Use zeros instead of empty.
    Co-authored-by: Mishig Davaadorj <dmishig@gmail.com>
    Co-authored-by: Mishig Davaadorj <mishig.davaadorj@coloradocollege.edu>
    Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
    Co-authored-by: Suraj Patil <surajp815@gmail.com>
scheduling_pndm_flax.py 21.2 KB