1. 25 Oct, 2022 1 commit
  2. 13 Oct, 2022 1 commit
    • Suraj Patil's avatar
      update flax scheduler API (#822) · 0a09af2f
      Suraj Patil authored
      * update flax scheduler API
      
      * remoev set format
      
      * fix call to scale_model_input
      
      * update flax pndm
      
      * use int32
      
      * update docstr
      0a09af2f
  3. 03 Oct, 2022 2 commits
    • Pedro Cuenca's avatar
      Fix import with Flax but without PyTorch (#688) · 688031c5
      Pedro Cuenca authored
      * Don't use `load_state_dict` if torch is not installed.
      
      * Define `SchedulerOutput` to use torch or flax arrays.
      
      * Don't import LMSDiscreteScheduler without torch.
      
      * Create distinct FlaxSchedulerOutput.
      
      * Additional changes required for FlaxSchedulerMixin
      
      * Do not import torch pipelines in Flax.
      
      * Revert "Define `SchedulerOutput` to use torch or flax arrays."
      
      This reverts commit f653140134b74d9ffec46d970eb46925fe3a409d.
      
      * Prefix Flax scheduler outputs for consistency.
      
      * make style
      
      * FlaxSchedulerOutput is now a dataclass.
      
      * Don't use f-string without placeholders.
      
      * Add blank line.
      
      * Style (docstrings)
      688031c5
    • Pedro Cuenca's avatar
      Flax: add shape argument to `set_timesteps` (#690) · 249b36cc
      Pedro Cuenca authored
      * Flax: add shape argument to set_timesteps
      
      * style
      249b36cc
  4. 29 Sep, 2022 1 commit
  5. 27 Sep, 2022 1 commit
    • Pedro Cuenca's avatar
      Flax pipeline pndm (#583) · ab3fd671
      Pedro Cuenca authored
      
      
      * WIP: flax FlaxDiffusionPipeline & FlaxStableDiffusionPipeline
      
      * todo comment
      
      * Fix imports
      
      * Fix imports
      
      * add dummies
      
      * Fix empty init
      
      * make pipeline work
      
      * up
      
      * Allow dtype to be overridden on model load.
      
      This may be a temporary solution until #567 is addressed.
      
      * Convert params to bfloat16 or fp16 after loading.
      
      This deals with the weights, not the model.
      
      * Use Flax schedulers (typing, docstring)
      
      * PNDM: replace control flow with jax functions.
      
      Otherwise jitting/parallelization don't work properly as they don't know
      how to deal with traced objects.
      
      I temporarily removed `step_prk`.
      
      * Pass latents shape to scheduler set_timesteps()
      
      PNDMScheduler uses it to reserve space, other schedulers will just
      ignore it.
      
      * Wrap model imports inside availability checks.
      
      * Optionally return state in from_config.
      
      Useful for Flax schedulers.
      
      * Do not convert model weights to dtype.
      
      * Re-enable PRK steps with functional implementation.
      
      Values returned still not verified for correctness.
      
      * Remove left over has_state var.
      
      * make style
      
      * Apply suggestion list -> tuple
      Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
      
      * Apply suggestion list -> tuple
      Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
      
      * Remove unused comments.
      
      * Use zeros instead of empty.
      Co-authored-by: default avatarMishig Davaadorj <dmishig@gmail.com>
      Co-authored-by: default avatarMishig Davaadorj <mishig.davaadorj@coloradocollege.edu>
      Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
      Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
      ab3fd671
  6. 21 Sep, 2022 1 commit
    • Pedro Cuenca's avatar
      Return Flax scheduler state (#601) · a9fdb3de
      Pedro Cuenca authored
      * Optionally return state in from_config.
      
      Useful for Flax schedulers.
      
      * has_state is now a property, make check more strict.
      
      I don't check the class is `SchedulerMixin` to prevent circular
      dependencies. It should be enough that the class name starts with "Flax"
      the object declares it "has_state" and the "create_state" exists too.
      
      * Use state in pipeline from_pretrained.
      
      * Make style
      a9fdb3de
  7. 20 Sep, 2022 1 commit
  8. 19 Sep, 2022 1 commit
  9. 15 Sep, 2022 1 commit
    • Kashif Rasul's avatar
      Karras VE, DDIM and DDPM flax schedulers (#508) · b34be039
      Kashif Rasul authored
      * beta never changes removed from state
      
      * fix typos in docs
      
      * removed unused var
      
      * initial ddim flax scheduler
      
      * import
      
      * added dummy objects
      
      * fix style
      
      * fix typo
      
      * docs
      
      * fix typo in comment
      
      * set return type
      
      * added flax ddom
      
      * fix style
      
      * remake
      
      * pass PRNG key as argument and split before use
      
      * fix doc string
      
      * use config
      
      * added flax Karras VE scheduler
      
      * make style
      
      * fix dummy
      
      * fix ndarray type annotation
      
      * replace returns a new state
      
      * added lms_discrete scheduler
      
      * use self.config
      
      * add_noise needs state
      
      * use config
      
      * use config
      
      * docstring
      
      * added flax score sde ve
      
      * fix imports
      
      * fix typos
      b34be039
  10. 13 Sep, 2022 3 commits
    • Kashif Rasul's avatar
      initial flax pndm schedular (#492) · 55f7ca3b
      Kashif Rasul authored
      * initial flax pndm
      
      * fix typo
      
      * use state
      
      * return state
      
      * add FlaxSchedulerOutput
      
      * fix style
      
      * add flax imports
      
      * make style
      
      * fix typos
      
      * return created state
      
      * make style
      
      * add torch/flax imports
      
      * docs
      
      * fixed typo
      
      * remove tensor_format
      
      * round instead of cast
      
      * ets is jnp array
      
      * remove copy
      55f7ca3b
    • Nathan Lambert's avatar
      Fix scheduler inference steps error with power of 3 (#466) · b56f1027
      Nathan Lambert authored
      * initial attempt at solving
      
      * fix pndm power of 3 inference_step
      
      * add power of 3 test
      
      * fix index in pndm test, remove ddim test
      
      * add comments, change to round()
      b56f1027
    • Nathan Lambert's avatar
      Scheduler docs update (#464) · da990633
      Nathan Lambert authored
      * update scheduler docs TODOs, fix typos
      
      * fix another typo
      da990633
  11. 08 Sep, 2022 3 commits
    • Patrick von Platen's avatar
      [Outputs] Improve syntax (#423) · f6fb3282
      Patrick von Platen authored
      
      
      * [Outputs] Improve syntax
      
      * improve more
      
      * fix docstring return
      
      * correct all
      
      * uP
      Co-authored-by: default avatarMishig Davaadorj <dmishig@gmail.com>
      f6fb3282
    • Pedro Cuenca's avatar
      Inference support for `mps` device (#355) · 5dda1735
      Pedro Cuenca authored
      * Initial support for mps in Stable Diffusion pipeline.
      
      * Initial "warmup" implementation when using mps.
      
      * Make some deterministic tests pass with mps.
      
      * Disable training tests when using mps.
      
      * SD: generate latents in CPU then move to device.
      
      This is especially important when using the mps device, because
      generators are not supported there. See for example
      https://github.com/pytorch/pytorch/issues/84288.
      
      In addition, the other pipelines seem to use the same approach: generate
      the random samples then move to the appropriate device.
      
      After this change, generating an image in MPS produces the same result
      as when using the CPU, if the same seed is used.
      
      * Remove prints.
      
      * Pass AutoencoderKL test_output_pretrained with mps.
      
      Sampling from `posterior` must be done in CPU.
      
      * Style
      
      * Do not use torch.long for log op in mps device.
      
      * Perform incompatible padding ops in CPU.
      
      UNet tests now pass.
      See https://github.com/pytorch/pytorch/issues/84535
      
      
      
      * Style: fix import order.
      
      * Remove unused symbols.
      
      * Remove MPSWarmupMixin, do not apply automatically.
      
      We do apply warmup in the tests, but not during normal use.
      This adopts some PR suggestions by @patrickvonplaten.
      
      * Add comment for mps fallback to CPU step.
      
      * Add README_mps.md for mps installation and use.
      
      * Apply `black` to modified files.
      
      * Restrict README_mps to SD, show measures in table.
      
      * Make PNDM indexing compatible with mps.
      
      Addresses #239.
      
      * Do not use float64 when using LDMScheduler.
      
      Fixes #358.
      
      * Fix typo identified by @patil-suraj
      Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
      
      * Adapt example to new output style.
      
      * Restore 1:1 results reproducibility with CompVis.
      
      However, mps latents need to be generated in CPU because generators
      don't work in the mps device.
      
      * Move PyTorch nightly to requirements.
      
      * Adapt `test_scheduler_outputs_equivalence` ton MPS.
      
      * mps: skip training tests instead of ignoring silently.
      
      * Make VQModel tests pass on mps.
      
      * mps ddim tests: warmup, increase tolerance.
      
      * ScoreSdeVeScheduler indexing made mps compatible.
      
      * Make ldm pipeline tests pass using warmup.
      
      * Style
      
      * Simplify casting as suggested in PR.
      
      * Add Known Issues to readme.
      
      * `isort` import order.
      
      * Remove _mps_warmup helpers from ModelMixin.
      
      And just make changes to the tests.
      
      * Skip tests using unittest decorator for consistency.
      
      * Remove temporary var.
      
      * Remove spurious blank space.
      
      * Remove unused symbol.
      
      * Remove README_mps.
      Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
      Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> 
      5dda1735
    • Nathan Lambert's avatar
      [docs sprint] schedulers docs, will update (#376) · e6110f68
      Nathan Lambert authored
      
      
      * init schedulers docs
      
      * add some docstrings, fix sidebar formatting
      
      * add docstrings
      
      * [Type hint] PNDM schedulers (#335)
      
      * [Type hint] PNDM Schedulers
      
      * ran make style
      
      * updated timesteps type hint
      
      * apply suggestions from code review
      
      * ran make style
      
      * removed unused import
      
      * [Type hint] scheduling ddim (#343)
      
      * [Type hint] scheduling ddim
      
      * apply suggestions from code review
      
      apply suggestions to also return the return type
      Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
      Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
      
      * make style
      
      * update class docstrings
      
      * add docstrings
      
      * missed merge edit
      
      * add general docs page
      
      * modify headings for right sidebar
      Co-authored-by: default avatarPartho <parthodas6176@gmail.com>
      Co-authored-by: default avatarSantiago Víquez <santi.viquez@gmail.com>
      Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
      e6110f68
  12. 05 Sep, 2022 1 commit
  13. 04 Sep, 2022 1 commit
    • Partho's avatar
      [Type hint] PNDM schedulers (#335) · dea5ec50
      Partho authored
      * [Type hint] PNDM Schedulers
      
      * ran make style
      
      * updated timesteps type hint
      
      * apply suggestions from code review
      
      * ran make style
      
      * removed unused import
      dea5ec50
  14. 31 Aug, 2022 2 commits
  15. 30 Aug, 2022 1 commit
  16. 22 Aug, 2022 1 commit
  17. 16 Aug, 2022 2 commits
  18. 12 Aug, 2022 1 commit
  19. 11 Aug, 2022 2 commits
  20. 21 Jul, 2022 2 commits
  21. 20 Jul, 2022 2 commits
  22. 18 Jul, 2022 1 commit
    • Nathan Lambert's avatar
      VE/VP SDE updates (#90) · 63c68d97
      Nathan Lambert authored
      
      
      * improve comments for sde_ve scheduler, init tests
      
      * more comments, tweaking pipelines
      
      * timesteps --> num_training_timesteps, some comments
      
      * merge cpu test, add m1 data
      
      * fix scheduler tests with num_train_timesteps
      
      * make np compatible, add tests for sde ve
      
      * minor default variable fixes
      
      * make style and fix-copies
      Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
      63c68d97
  23. 15 Jul, 2022 3 commits
  24. 27 Jun, 2022 1 commit
  25. 22 Jun, 2022 1 commit
  26. 20 Jun, 2022 3 commits