1. 04 Oct, 2022 3 commits
  2. 30 Sep, 2022 3 commits
    • Fix slow tests (#689) · b2cfc7a0
      Nouamane Tazi authored
      * revert using baddbmm in attention
      - to fix `test_stable_diffusion_memory_chunking` test
      
      * styling
    • Allow resolutions that are not multiples of 64 (#505) · a784be2e
      Josh Achiam authored
      * Allow resolutions that are not multiples of 64
      
      * ran black
      
      * fix bug
      
      * add test
      
      * more explanation
      
      * more comments
      Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
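      A minimal sketch of what the relaxed constraint amounts to, assuming the
      pipeline only needs spatial dims divisible by 8 (the VAE downsampling
      factor); `validate_resolution` is a hypothetical helper, not the PR's code:

      ```python
      def validate_resolution(height: int, width: int, factor: int = 8) -> None:
          # Assumption: the VAE downsamples by a factor of 8, so any multiple
          # of 8 works, not just multiples of 64.
          if height % factor != 0 or width % factor != 0:
              raise ValueError(
                  f"`height` and `width` must be divisible by {factor} "
                  f"but are {height} and {width}."
              )

      validate_resolution(512, 712)  # 712 is a multiple of 8 but not of 64: now allowed
      ```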
    • Optimize Stable Diffusion (#371) · 9ebaea54
      Nouamane Tazi authored
      * initial commit
      
      * make UNet stream capturable
      
      * try to fix noise_pred value
      
      * remove cuda graph and keep NB
      
      * non-blocking unet with PNDMScheduler
      
      * make timesteps np arrays for pndm scheduler
      because lists don't get formatted to tensors in `self.set_format`
      
      * make max async in pndm
      
      * use channel last format in unet
      
      * avoid moving timesteps device in each unet call
      
      * avoid memcpy op in `get_timestep_embedding`
      
      * add `channels_last` kwarg to `DiffusionPipeline.from_pretrained`
      
      * update TODO
      
      * replace `channels_last` kwarg with `memory_format` for more generality
      
      * revert the channels_last changes to leave it for another PR
      
      * remove non_blocking when moving input ids to device
      
      * remove blocking from all .to() operations at beginning of pipeline
      
      * fix merging
      
      * fix merging
      
      * model can run in other precisions without autocast
      
      * attn refactoring
      
      * Revert "attn refactoring"
      
      This reverts commit 0c70c0e189cd2c4d8768274c9fcf5b940ee310fb.
      
      * remove restriction to run conv_norm in fp32
      
      * use `baddbmm` instead of `matmul` in attention for better perf (see the sketch after this entry)
      
      * removing all reshapes to test perf
      
      * Revert "removing all reshapes to test perf"
      
      This reverts commit 006ccb8a8c6bc7eb7e512392e692a29d9b1553cd.
      
      * add shapes comments
      
      * hardcode what's needed for jitting
      
      * Revert "hardcore whats needed for jitting"
      
      This reverts commit 2fa9c698eae2890ac5f8e367ca80532ecf94df9a.
      
      * Revert "remove restriction to run conv_norm in fp32"
      
      This reverts commit cec592890c32da3d1b78d38b49e4307aedf459b9.
      
      * revert using baddbmm in attention's forward
      
      * cleanup comment
      
      * remove restriction to run conv_norm in fp32. No quality loss was noticed
      
      This reverts commit cc9bc1339c998ebe9e7d733f910c6d72d9792213.
      
      * add more optimization techniques to docs
      
      * Revert "add shapes comments"
      
      This reverts commit 31c58eadb8892f95478cdf05229adf678678c5f4.
      
      * apply suggestions
      
      * make quality
      
      * apply suggestions
      
      * styling
      
      * `scheduler.timesteps` are now arrays so we don't need `.to()`
      
      * remove useless .type()
      
      * use mean instead of max in `test_stable_diffusion_inpaint_pipeline_k_lms`
      
      * move scheduler timesteps to correct device if tensors
      
      * add device to `set_timesteps` in LMSD scheduler
      
      * `self.scheduler.set_timesteps` now uses device arg for schedulers that accept it
      
      * quick fix
      
      * styling
      
      * remove kwargs from schedulers `set_timesteps`
      
      * revert to using max in K-LMS inpaint pipeline test
      
      * Revert "`self.scheduler.set_timesteps` now uses device arg for schedulers that accept it"
      
      This reverts commit 00d5a51e5c20d8d445c8664407ef29608106d899.
      
      * move timesteps to correct device before loop in SD pipeline
      
      * apply previous fix to other SD pipelines
      
      * UNet now accepts tensor timesteps even on the wrong device, to avoid errors
      - it shouldn't affect performance if timesteps are already on the correct device
      - it does slow down performance if they're on the wrong device
      
      * fix pipeline when timesteps are arrays with strides
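      A minimal sketch of the `baddbmm` attention trick mentioned in this entry
      (and rolled back in "Fix slow tests (#689)" above); shapes and values are
      illustrative, not the pipeline's actual ones:

      ```python
      import torch

      b, n, d = 2, 64, 40  # (batch*heads, seq_len, head_dim), illustrative
      q, k = torch.randn(b, n, d), torch.randn(b, n, d)
      scale = d ** -0.5

      # matmul version: the scale is a separate elementwise op
      scores_matmul = torch.matmul(q, k.transpose(1, 2)) * scale

      # baddbmm version: the scale is fused into the batched matmul via `alpha`;
      # with beta=0 the first argument only supplies shape/dtype
      scores_baddbmm = torch.baddbmm(
          torch.empty(b, n, n, dtype=q.dtype, device=q.device),
          q, k.transpose(1, 2), beta=0, alpha=scale,
      )

      assert torch.allclose(scores_matmul, scores_baddbmm, atol=1e-4)
      ```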
  3. 29 Sep, 2022 1 commit
  4. 27 Sep, 2022 1 commit
  5. 23 Sep, 2022 1 commit
    • Flax documentation (#589) · 8b0be935
      Younes Belkada authored
      * documenting `attention_flax.py` file
      
      * documenting `embeddings_flax.py`
      
      * documenting `unet_blocks_flax.py`
      
      * Add new objs to doc page
      
      * document `vae_flax.py`
      
      * Apply suggestions from code review
      
      * modify `unet_2d_condition_flax.py`
      
      * make style
      
      * Apply suggestions from code review
      
      * make style
      
      * Apply suggestions from code review
      
      * fix indent
      
      * fix typo
      
      * fix indent unet
      
      * Update src/diffusers/models/vae_flax.py
      
      * Apply suggestions from code review
      Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
      Co-authored-by: Mishig Davaadorj <dmishig@gmail.com>
  6. 22 Sep, 2022 1 commit
    • [UNet2DConditionModel] add gradient checkpointing (#461) · e7120bae
      Suraj Patil authored
      * add grad ckpt to downsample blocks
      
      * make it work
      
      * don't pass gradient_checkpointing to upsample block
      
      * add tests for UNet2DConditionModel
      
      * add test_gradient_checkpointing
      
      * add gradient_checkpointing for up and down blocks
      
      * add functions to enable and disable grad ckpt
      
      * remove the forward argument
      
      * better naming
      
      * make supports_gradient_checkpointing private
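      A minimal sketch of the gradient-checkpointing pattern this PR adds, using
      a toy block; the real UNet blocks are more involved:

      ```python
      import torch
      from torch.utils.checkpoint import checkpoint

      class ToyDownBlock(torch.nn.Module):
          def __init__(self):
              super().__init__()
              self.conv = torch.nn.Conv2d(4, 4, kernel_size=3, padding=1)
              self.gradient_checkpointing = False  # what the enable/disable helpers toggle

          def forward(self, hidden_states):
              if self.training and self.gradient_checkpointing:
                  # activations are recomputed during backward instead of stored
                  return checkpoint(self.conv, hidden_states)
              return self.conv(hidden_states)

      block = ToyDownBlock().train()
      block.gradient_checkpointing = True
      out = block(torch.randn(1, 4, 8, 8, requires_grad=True))
      out.sum().backward()
      ```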
  7. 21 Sep, 2022 1 commit
  8. 20 Sep, 2022 4 commits
    • [Flax] Fix unet and ddim scheduler (#594) · 2345481c
      Patrick von Platen authored
      * [Flax] Fix unet and ddim scheduler
      
      * correct
      
      * finish
    • FlaxDiffusionPipeline & FlaxStableDiffusionPipeline (#559) · d934d3d7
      Mishig Davaadorj authored
      * WIP: flax FlaxDiffusionPipeline & FlaxStableDiffusionPipeline
      
      * todo comment
      
      * Fix imports
      
      * Fix imports
      
      * add dummies
      
      * Fix empty init
      
      * make pipeline work
      
      * up
      
      * Use Flax schedulers (typing, docstring)
      
      * Wrap model imports inside availability checks.
      
      * more updates
      
      * make sure flax is not broken
      
      * make style
      
      * more fixes
      
      * up
      Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
      Co-authored-by: Pedro Cuenca <pedro@latenitesoft.com>
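      A minimal sketch of the "add dummies" / availability-check pattern this
      entry mentions; the layout is illustrative, not the library's exact code:

      ```python
      try:
          import flax  # noqa: F401
          _flax_available = True
      except ImportError:
          _flax_available = False

      if _flax_available:
          from diffusers import FlaxStableDiffusionPipeline
      else:
          class FlaxStableDiffusionPipeline:
              # dummy object: importable everywhere, fails loudly only when used
              def __init__(self, *args, **kwargs):
                  raise ImportError("FlaxStableDiffusionPipeline requires the `flax` library.")
      ```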
    • [FlaxAutoencoderKL] rename weights to align with PT (#584) · c01ec2d1
      Suraj Patil authored
      * rename weights to align with PT
      
      * DiagonalGaussianDistribution => FlaxDiagonalGaussianDistribution
      
      * fix name
    • Add `from_pt` argument in `.from_pretrained` (#527) · 0902449e
      Younes Belkada authored
      * first commit:
      
      - add `from_pt` argument in `from_pretrained` function
      - add `modeling_flax_pytorch_utils.py` file
      
      * small nit
      
      - fix a small nit so we don't enter the second if condition
      
      * major changes
      
      - modify FlaxUnet modules
      - first conversion script
      - more keys to be matched
      
      * keys match
      
      - now all keys match
      - change module names for correct matching
      - upsample module name changed
      
      * working v1
      
      - tests pass with atol and rtol = `4e-02`
      
      * replace unused arg
      
      * make quality
      
      * add small docstring
      
      * add more comments
      
      - add TODO for embedding layers
      
      * small change
      
      - use `jnp.expand_dims` for converting `timesteps` in case it is a 0-dimensional array
      
      * add more conditions on conversion
      
      - add better test to check for keys conversion
      
      * make shapes consistent
      
      - output `img_w x img_h x n_channels` from the VAE
      
      * Revert "make shapes consistent"
      
      This reverts commit 4cad1aeb4aeb224402dad13c018a5d42e96267f6.
      
      * fix unet shape
      
      - channels first!
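      A minimal sketch of the central step in such a PyTorch-to-Flax conversion,
      assuming conv kernels need the usual axis reordering; the helper name is
      hypothetical:

      ```python
      import numpy as np

      def pt_conv_kernel_to_flax(pt_weight: np.ndarray) -> np.ndarray:
          # PyTorch Conv2d kernels: (out_ch, in_ch, kH, kW)
          # Flax nn.Conv kernels:   (kH, kW, in_ch, out_ch)
          return pt_weight.transpose(2, 3, 1, 0)

      pt_kernel = np.zeros((320, 4, 3, 3), dtype=np.float32)
      assert pt_conv_kernel_to_flax(pt_kernel).shape == (3, 3, 4, 320)
      ```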
  9. 19 Sep, 2022 6 commits
  10. 18 Sep, 2022 1 commit
  11. 16 Sep, 2022 2 commits
  12. 15 Sep, 2022 2 commits
    • UNet Flax with FlaxModelMixin (#502) · d8b0e4f4
      Pedro Cuenca authored
      * First UNet Flax modeling blocks.
      
      Mimic the structure of the PyTorch files.
      The model classes themselves need work, depending on what we do about
      configuration and initialization.
      
      * Remove FlaxUNet2DConfig class.
      
      * ignore_for_config non-config args.
      
      * Implement `FlaxModelMixin`
      
      * Use new mixins for Flax UNet.
      
      For some reason the configuration is not correctly applied; the
      signature of the `__init__` method does not contain all the parameters
      by the time it's inspected in `extract_init_dict`.
      
      * Import `FlaxUNet2DConditionModel` if flax is available.
      
      * Rm unused method `framework`
      
      * Update src/diffusers/modeling_flax_utils.py
      Co-authored-by: Suraj Patil <surajp815@gmail.com>
      
      * Indicate types in flax.struct.dataclass as pointed out by @mishig25
      Co-authored-by: Mishig Davaadorj <mishig.davaadorj@coloradocollege.edu>
      
      * Fix typo in transformer block.
      
      * make style
      
      * some more changes
      
      * make style
      
      * Add comment
      
      * Update src/diffusers/modeling_flax_utils.py
      Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
      
      * Rm unneeded comment
      
      * Update docstrings
      
      * correct ignore kwargs
      
      * make style
      
      * Update docstring examples
      
      * Make style
      
      * Style: remove empty line.
      
      * Apply style (after upgrading black from pinned version)
      
      * Remove some commented code and unused imports.
      
      * Add init_weights (not yet in use until #513).
      
      * Trickle down deterministic to blocks.
      
      * Rename q, k, v according to the latest PyTorch version.
      
      Note that weights were exported with the old names, so we need to be
      careful.
      
      * Flax UNet docstrings, default props as in PyTorch.
      
      * Fix minor typos in PyTorch docstrings.
      
      * Use FlaxUNet2DConditionOutput as output from UNet.
      
      * make style
      Co-authored-by: Mishig Davaadorj <dmishig@gmail.com>
      Co-authored-by: Mishig Davaadorj <mishig.davaadorj@coloradocollege.edu>
      Co-authored-by: Suraj Patil <surajp815@gmail.com>
      Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
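      A minimal sketch of the typed `flax.struct.dataclass` output container the
      entry describes (`FlaxUNet2DConditionOutput`); the field is illustrative:

      ```python
      import jax.numpy as jnp
      from flax import struct

      @struct.dataclass
      class ToyUNetOutput:
          sample: jnp.ndarray  # field type declared explicitly

      out = ToyUNetOutput(sample=jnp.zeros((1, 64, 64, 4)))
      print(out.sample.shape)  # (1, 64, 64, 4)
      ```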
    • [UNet2DConditionModel, UNet2DModel] pass norm_num_groups to all the blocks (#442) · d144c46a
      Suraj Patil authored
      * pass norm_num_groups to unet blocks and attention
      
      * fix UNet2DConditionModel
      
      * add norm_num_groups arg in vae
      
      * add tests
      
      * remove comment
      
      * Apply suggestions from code review
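      A minimal sketch of what threading `norm_num_groups` down amounts to:
      every GroupNorm is built from the config value instead of a hardcoded 32.
      Values are illustrative:

      ```python
      import torch

      norm_num_groups = 16  # illustrative; must divide the channel count
      channels = 64
      norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels, eps=1e-6)
      print(norm(torch.randn(1, channels, 32, 32)).shape)  # torch.Size([1, 64, 32, 32])
      ```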
  13. 14 Sep, 2022 2 commits
  14. 12 Sep, 2022 1 commit
    • update expected results of slow tests (#268) · f4781a0b
      Kashif Rasul authored
      * update expected results of slow tests
      
      * relax sum and mean tests
      
      * Print shapes when reporting exception
      
      * formatting
      
      * fix sentence
      
      * relax test_stable_diffusion_fast_ddim for gpu fp16
      
      * relax flaky tests on GPU
      
      * added comment on large tolerances
      
      * black
      
      * format
      
      * set scheduler seed
      
      * added generator
      
      * use np.isclose
      
      * set num_inference_steps to 50
      
      * fix dep. warning
      
      * update expected_slice
      
      * preprocess if image
      
      * updated expected results
      
      * updated expected from CI
      
      * pass generator to VAE
      
      * undo change back to orig
      
      * use original
      
      * revert back the expected on cpu
      
      * revert back values for CPU
      
      * more undo
      
      * update result after using gen
      
      * update mean
      
      * set generator for mps
      
      * update expected on CI server
      
      * undo
      
      * use new seed every time
      
      * cpu manual seed
      
      * reduce num_inference_steps
      
      * style
      
      * use generator for randn
      Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
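      A minimal sketch of the testing pattern these commits converge on: seed a
      generator, then compare output slices with `np.isclose` under a tolerance
      instead of exact equality. The values are illustrative, not from the real
      tests:

      ```python
      import numpy as np
      import torch

      generator = torch.manual_seed(0)  # "cpu manual seed"
      image = torch.randn(1, 3, 64, 64, generator=generator).numpy()

      image_slice = image[0, 0, :3, :3].flatten()
      expected_slice = image_slice.copy()  # real tests hardcode values from a reference run

      assert np.isclose(image_slice, expected_slice, atol=1e-2).all()
      ```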
  15. 09 Sep, 2022 2 commits
  16. 08 Sep, 2022 3 commits
    • [Black] Update black (#433) · b2b3b1a8
      Patrick von Platen authored
      * Update black
      
      * update table
    • [Docs] Models (#416) · 5e6417e9
      Kashif Rasul authored
      * docs for attention
      
      * types for embeddings
      
      * unet2d docstrings
      
      * UNet2DConditionModel docstrings
      
      * fix typos
      
      * style and vq-vae docstrings
      
      * docstrings for VAE
      
      * Update src/diffusers/models/unet_2d.py
      Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
      
      * make style
      
      * added "inherits from" sentence
      
      * docstring to forward
      
      * make style
      
      * Apply suggestions from code review
      Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
      
      * finish model docs
      
      * up
      Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
      Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
    • Inference support for `mps` device (#355) · 5dda1735
      Pedro Cuenca authored
      * Initial support for mps in Stable Diffusion pipeline.
      
      * Initial "warmup" implementation when using mps.
      
      * Make some deterministic tests pass with mps.
      
      * Disable training tests when using mps.
      
      * SD: generate latents in CPU then move to device (see the sketch after this entry).
      
      This is especially important when using the mps device, because
      generators are not supported there. See for example
      https://github.com/pytorch/pytorch/issues/84288.
      
      In addition, the other pipelines seem to use the same approach: generate
      the random samples then move to the appropriate device.
      
      After this change, generating an image in MPS produces the same result
      as when using the CPU, if the same seed is used.
      
      * Remove prints.
      
      * Pass AutoencoderKL test_output_pretrained with mps.
      
      Sampling from `posterior` must be done in CPU.
      
      * Style
      
      * Do not use torch.long for log op in mps device.
      
      * Perform incompatible padding ops in CPU.
      
      UNet tests now pass.
      See https://github.com/pytorch/pytorch/issues/84535
      
      * Style: fix import order.
      
      * Remove unused symbols.
      
      * Remove MPSWarmupMixin, do not apply automatically.
      
      We do apply warmup in the tests, but not during normal use.
      This adopts some PR suggestions by @patrickvonplaten.
      
      * Add comment for mps fallback to CPU step.
      
      * Add README_mps.md for mps installation and use.
      
      * Apply `black` to modified files.
      
      * Restrict README_mps to SD, show measures in table.
      
      * Make PNDM indexing compatible with mps.
      
      Addresses #239.
      
      * Do not use float64 when using LDMScheduler.
      
      Fixes #358.
      
      * Fix typo identified by @patil-suraj
      Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
      
      * Adapt example to new output style.
      
      * Restore 1:1 results reproducibility with CompVis.
      
      However, mps latents need to be generated in CPU because generators
      don't work in the mps device.
      
      * Move PyTorch nightly to requirements.
      
      * Adapt `test_scheduler_outputs_equivalence` to MPS.
      
      * mps: skip training tests instead of ignoring silently.
      
      * Make VQModel tests pass on mps.
      
      * mps ddim tests: warmup, increase tolerance.
      
      * ScoreSdeVeScheduler indexing made mps compatible.
      
      * Make ldm pipeline tests pass using warmup.
      
      * Style
      
      * Simplify casting as suggested in PR.
      
      * Add Known Issues to readme.
      
      * `isort` import order.
      
      * Remove _mps_warmup helpers from ModelMixin.
      
      And just make changes to the tests.
      
      * Skip tests using unittest decorator for consistency.
      
      * Remove temporary var.
      
      * Remove spurious blank space.
      
      * Remove unused symbol.
      
      * Remove README_mps.
      Co-authored-by: Suraj Patil <surajp815@gmail.com>
      Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
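      A minimal sketch of the reproducibility fix described in "SD: generate
      latents in CPU then move to device": torch generators are unsupported on
      `mps`, so sampling happens on CPU and the result is moved afterwards,
      keeping `mps` output identical to CPU for the same seed:

      ```python
      import torch

      generator = torch.Generator(device="cpu").manual_seed(42)
      latents = torch.randn((1, 4, 64, 64), generator=generator)  # sampled on CPU

      if torch.backends.mps.is_available():
          latents = latents.to("mps")  # moved after sampling; values match the CPU run
      ```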
  17. 07 Sep, 2022 1 commit
  18. 06 Sep, 2022 2 commits
  19. 05 Sep, 2022 3 commits