1. 20 Sep, 2022 2 commits
    • Mishig Davaadorj's avatar
      FlaxDiffusionPipeline & FlaxStableDiffusionPipeline (#559) · d934d3d7
      Mishig Davaadorj authored
      
      
      * WIP: flax FlaxDiffusionPipeline & FlaxStableDiffusionPipeline
      
      * todo comment
      
      * Fix imports
      
      * Fix imports
      
      * add dummies
      
      * Fix empty init
      
      * make pipeline work
      
      * up
      
      * Use Flax schedulers (typing, docstring)
      
      * Wrap model imports inside availability checks.
      
      * more updates
      
      * make sure flax is not broken
      
      * make style
      
      * more fixes
      
      * up
      Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
      Co-authored-by: default avatarPedro Cuenca <pedro@latenitesoft.com>
      d934d3d7
    • Younes Belkada's avatar
      Add `from_pt` argument in `.from_pretrained` (#527) · 0902449e
      Younes Belkada authored
      * first commit:
      
      - add `from_pt` argument in `from_pretrained` function
      - add `modeling_flax_pytorch_utils.py` file
      
      * small nit
      
      - fix a small nit - to not enter in the second if condition
      
      * major changes
      
      - modify FlaxUnet modules
      - first conversion script
      - more keys to be matched
      
      * keys match
      
      - now all keys match
      - change module names for correct matching
      - upsample module name changed
      
      * working v1
      
      - test pass with atol and rtol= `4e-02`
      
      * replace unsued arg
      
      * make quality
      
      * add small docstring
      
      * add more comments
      
      - add TODO for embedding layers
      
      * small change
      
      - use `jnp.expand_dims` for converting `timesteps` in case it is a 0-dimensional array
      
      * add more conditions on conversion
      
      - add better test to check for keys conversion
      
      * make shapes consistent
      
      - output `img_w x img_h x n_channels` from the VAE
      
      * Revert "make shapes consistent"
      
      This reverts commit 4cad1aeb4aeb224402dad13c018a5d42e96267f6.
      
      * fix unet shape
      
      - channels first!
      0902449e
  2. 19 Sep, 2022 2 commits
  3. 15 Sep, 2022 2 commits
    • Mishig Davaadorj's avatar
      Add `init_weights` method to `FlaxMixin` (#513) · fb5468a6
      Mishig Davaadorj authored
      
      
      * Add `init_weights` method to `FlaxMixin`
      
      * Rn `random_state` -> `shape_state`
      
      * `PRNGKey(0)` for `jax.eval_shape`
      
      * No allow mismatched sizes
      
      * Update src/diffusers/modeling_flax_utils.py
      Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
      
      * Update src/diffusers/modeling_flax_utils.py
      Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
      
      * docstring diffusers
      Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
      fb5468a6
    • Kashif Rasul's avatar
      Karras VE, DDIM and DDPM flax schedulers (#508) · b34be039
      Kashif Rasul authored
      * beta never changes removed from state
      
      * fix typos in docs
      
      * removed unused var
      
      * initial ddim flax scheduler
      
      * import
      
      * added dummy objects
      
      * fix style
      
      * fix typo
      
      * docs
      
      * fix typo in comment
      
      * set return type
      
      * added flax ddom
      
      * fix style
      
      * remake
      
      * pass PRNG key as argument and split before use
      
      * fix doc string
      
      * use config
      
      * added flax Karras VE scheduler
      
      * make style
      
      * fix dummy
      
      * fix ndarray type annotation
      
      * replace returns a new state
      
      * added lms_discrete scheduler
      
      * use self.config
      
      * add_noise needs state
      
      * use config
      
      * use config
      
      * docstring
      
      * added flax score sde ve
      
      * fix imports
      
      * fix typos
      b34be039
  4. 14 Sep, 2022 1 commit